introducción detallada a torchvision

Prefacio

El camino hacia el aprendizaje profundo es largo, solo resumiendo constantemente y teniendo los pies en la tierra podemos lograr el éxito, también debemos seguir animándonos, no te rindas, ¡cree que puedes hacerlo! ! ! ¡Animo mutuo! ! !

Introducción a la visión de la antorcha

torchvisionEs pytorchuna biblioteca de gráficos que sirve para PyTorchmarcos de aprendizaje profundo y se utiliza principalmente para construir modelos de visión por computadora . La siguiente es torchvisionla composición:

  1. torchvision.datasets: Algunas funciones para cargar datos e interfaces de conjuntos de datos de uso común;
  2. torchvision.models: Contiene estructuras de modelos de uso común (incluidos modelos previamente entrenados), como AlexNet, VGG, ResNet, etc.;
  3. torchvision.transforms: Transformaciones de imágenes de uso común, como recortar, rotar, etc.;
  4. torchvision.utils: Algunos otros métodos útiles.

torchvision.transforms

torchvision.transformsSe utiliza principalmente para algunas transformaciones de gráficos comunes .
torchvision.transforms.Compose()amable. La función principal de esta clase es encadenar múltiples operaciones de transformación de imágenes. La construcción de esta clase es simple:

# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(), # 转化为Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])

torchvision.conjuntos de datos

torchvision.datasetsSe utiliza para la carga de datos. El equipo de PyTorch nos ha ayudado a procesar muchos conjuntos de datos de imágenes por adelantado en este paquete.

MNISTCOCO Detección
de subtítulos LSUN ImageFolder Imagenet-12 CIFAR STL10 SVHN PhotoTour








# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# MNIST dataset
mnist = datasets.MNIST(
    root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)

torchvision.modelos

torchvision.modelsNos proporciona un modelo entrenado que podemos cargar y usar directamente.

torchvision.modelsLas siguientes estructuras de modelo se incluyen en los submódulos del módulo.

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
...

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

También puedes cargar un modelo previamente entrenado por otros usando pretrained=True

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)

efecto global

# 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader
 
# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5)) # 归一化
])
 
DOWNLOAD = True
BATCH_SIZE = 32
 
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)
 
 
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
 
print(len(train_dataset))
print(len(train_loader))

Supongo que te gusta

Origin blog.csdn.net/frighting_ing/article/details/121863387
Recomendado
Clasificación