[Pytorch Framework] 1.6 Entrenamiento de un clasificador

Entrenar a un clasificador

En la última lección, hemos visto cómo definir una red neuronal, calcular el valor de pérdida y actualizar el peso de la red.
Es posible que ahora esté pensando en el próximo paso.

¿Sobre los datos?

Por lo general, al procesar datos de imagen, texto, audio y video, puede usar paquetes estándar de Python para cargar los datos en una matriz numerosa.
Luego convierta esta matriz en torch.*Tensor.

  • Las imágenes pueden usar Pillow, OpenCV
  • El audio puede usar scipy, librosa
  • El texto se puede cargar usando Python y Cython sin
    procesar , o procesarse usando NLTK o SpaCy

En particular, para las tareas de imágenes, hemos creado un paquete
torchvisionque contiene métodos para procesar algunos conjuntos de datos de imágenes básicos. Estos conjuntos de datos incluyen
Imagenet, CIFAR10, MNIST, etc. Además de la carga de datos, torchvisiontambién contiene convertidores de imágenes
torchvision.datasetsy torch.utils.data.DataLoader.

torchvisionEl paquete no solo proporciona una gran comodidad, sino que también evita la duplicación de código.

En este tutorial, usamos el conjunto de datos CIFAR10, que tiene las siguientes 10 categorías
: 'avión', 'automóvil', 'pájaro', 'gato', 'ciervo',
'perro', 'rana', 'caballo', ' barco ',' camión '. Las imágenes CIFAR-10 tienen
un tamaño de 3x32x32, es decir, 3 canales de color, 32x32 píxeles.

Entrena un clasificador de imágenes

Proceda sucesivamente en el siguiente orden:

  1. Utilice la torchvisioncarga y normalización del conjunto de entrenamiento y de prueba CIFAR10

  2. Definir una red neuronal convolucional

  3. Definir función de pérdida

  4. Entrene a la red en el conjunto de entrenamiento

  5. Pruebe la red en el equipo de prueba

  6. Leer y normalizar CIFAR10


El uso torchvisionse puede cargar fácilmente CIFAR10.

import torch
import torchvision
import torchvision.transforms as transforms

La salida de torchvision es una imagen PILImage de [0,1], y la convertimos a un tensor con un rango normalizado de [-1, 1].

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████████▉| 170M/170M [20:39<00:00, 155kB/s]

Files already downloaded and verified

Mostramos algunas imágenes de entrenamiento.

import matplotlib.pyplot as plt
import numpy as np

# 展示图像的函数


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
171MB [20:51, 155kB/s]                                                                                                 

  cat   car   cat  ship
  1. Definir una red neuronal convolucional

Copie el código de la red neuronal de la sección anterior de la red neuronal y modifíquelo para ingresar una imagen de 3 canales.

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
  1. Definir la función de pérdida y el optimizador

Usamos la entropía cruzada como función de pérdida y usamos el descenso de gradiente estocástico con impulso.

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  1. Red de formación

Comenzó el momento divertido.
Solo necesitamos hacer un bucle en el iterador de datos, alimentar los datos a la red y optimizar.

for epoch in range(2):  # 多批次循环

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入
        inputs, labels = data

        # 梯度置0
        optimizer.zero_grad()

        # 正向传播,反向传播,优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印状态信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000批次打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
  1. Pruebe la red en el equipo de prueba

Realizamos 2 capacitaciones sobre todo el conjunto de capacitación, pero debemos verificar si la red ha aprendido algo útil del conjunto de datos.
La detección se realiza comparando las etiquetas de categoría generadas por la red neuronal predictiva con las etiquetas de situación real.
Si la predicción es correcta, agregamos la muestra a la lista de predicciones correctas.
El primer paso es mostrar las imágenes en el conjunto de prueba y familiarizarse con el contenido de las imágenes.

dataiter = iter(testloader)
images, labels = dataiter.next()

# 显示图片
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth:    cat  ship  ship plane

[Error en la transferencia de la imagen del enlace externo. El sitio de origen puede tener un mecanismo de enlace anti-sanguijuela. Se recomienda guardar la imagen y subirla directamente (img-7N9Ui0NR-1612164760383) (output_14_1.png)]

Veamos qué piensa la red neuronal que es la imagen de arriba.

outputs = net(images)

La salida es la energía de 10 etiquetas.
Cuanto mayor es la energía de una categoría, más la red neuronal considera que es esa categoría. Así que consigamos la etiqueta energética más alta.

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
Predicted:  plane plane plane plane

El resultado luce bien.

Echemos un vistazo a los resultados de la red en todo el conjunto de pruebas.

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 9 %

El resultado se ve bien, al menos mejor que la selección aleatoria, que es correcta en un 10%.
Parece que Internet ha aprendido algo.

¿Qué categoría es buena y cuál no?

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 99 %
Accuracy of   car :  0 %
Accuracy of  bird :  0 %
Accuracy of   cat :  0 %
Accuracy of  deer :  0 %
Accuracy of   dog :  0 %
Accuracy of  frog :  0 %
Accuracy of horse :  0 %
Accuracy of  ship :  0 %
Accuracy of truck :  0 %

¿Próximo paso?

¿Cómo ejecutamos redes neuronales en la GPU?

Entrenar en GPU

Mover una red neuronal a la GPU para entrenar es tan simple como convertir un tensor en la GPU. Y esta operación atravesará recursivamente algunos módulos y convertirá sus parámetros y búferes en tensores CUDA.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 确认我们的电脑支持CUDA,然后显示CUDA信息:

print(device)

El resto de esta sección asume deviceun dispositivo CUDA.

Luego, estos métodos atravesarán recursivamente todos los módulos y
convertirán los parámetros y búferes del módulo en tensores CUDA:

    net.to(device)

Recuerde: las entradas, los objetivos y las imágenes también deben convertirse.

        inputs, labels = inputs.to(device), labels.to(device)

¿Por qué no nos hemos dado cuenta de que la velocidad de la GPU ha aumentado tanto? Eso es porque la red es muy pequeña.

Práctica:
Intente aumentar el ancho de su red (el nn.Conv2dsegundo parámetro del primero, nn.Conv2del primer parámetro del segundo , deben ser el mismo número) y vea qué aceleración obtiene.

Objetivos alcanzados :

  • Comprensión profunda de la biblioteca de tensores y la red neuronal de PyTorch
  • Entrenó una pequeña red para clasificar imágenes.

Nota del traductor: Más adelante en nuestro tutorial, entrenaremos una red real para lograr una tasa de reconocimiento de más del 90%.

Entrenamiento multi-GPU

Si desea utilizar todas las GPU para obtener una mayor velocidad,
consulte Procesamiento paralelo de datos .

¿Próximo paso?

  • :Doc:训练神经网络玩电子游戏 </intermediate/reinforcement_q_learning>
  • 在ImageNet上训练最好的ResNet
  • 使用对抗生成网络来训练一个人脸生成器
  • 使用LSTM网络训练一个字符级的语言模型
  • 更多示例
  • 更多教程
  • 在论坛上讨论PyTorch
  • Slack上与其他用户讨论

Supongo que te gusta

Origin blog.csdn.net/yegeli/article/details/113520680
Recomendado
Clasificación