Pytorch de reconocimiento de dígitos escritos a mano de conjunto de datos clásico

conjunto de datos clásico de pytorch: reconocimiento de dígitos escritos a mano

1. ¿Qué es MNIST?

MNIST es el conjunto de datos más básico en el campo de la visión por computadora y también es el primer modelo de red neuronal para muchas personas.

El conjunto de datos MNIST (base de datos del Instituto Nacional Mixto de Estándares y Tecnología) es un gran conjunto de datos de dígitos escritos a mano recopilados por el Instituto Nacional de Estándares y Tecnología, que contiene un conjunto de entrenamiento de 60,000 muestras y un conjunto de prueba de 10,000 muestras.

Todas las muestras en MNIST convertirán la imagen original en escala de grises de 28 * 28 en un vector unidimensional con una longitud de 784 como entrada, donde cada elemento corresponde al valor de escala de grises en la imagen en escala de grises. MNIST utiliza un vector one-hot de longitud 10 como etiqueta correspondiente a la muestra, donde el valor del índice del vector corresponde a la probabilidad predicha de que la muestra dará como resultado ese índice.

2. Introducción detallada del código

El objetivo principal del reconocimiento de dígitos escritos a mano de MNIST es entrenar un modelo para que pueda clasificar imágenes de dígitos escritos a mano.

Primero comprenda los pasos y el proceso, y luego comience a construir la estructura de la red y a entrenar el modelo.

Importar las bibliotecas que se utilizarán.

Utils es un archivo externo, con varias funciones definidas por usted mismo, el código detallado se encuentra al final del artículo.

#导入需要的各种库
import torch
#神经网络
from torch import nn
#function神经网络中常见的函数
from torch.nn import functional as F
#梯度下降优化包
from torch import optim
#图形视觉包
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

Cargar conjunto de datos

#1 加载数据集
#load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)#shuffle打乱

#预览训练集数据
x, y=next(iter(train_loader))
print(x.shape,y.shape,x.min(),x.max())

#画图,图片识别,识别结果
plot_image(x,y,'image_sample')

Utilice el modelo Net para crear una estructura de red de tres capas + agregue una capa de función de activación relu


#2 创建网络
#制作三层线性网络层 + relu函数 网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

    #三层线性 xw +b
    #第一层 28*28 =》打平成一个向量 输出是中间层,一般取2^n,逐步减小
        #Linear(输入,输出)
        self.fc1 = nn.Linear(28*28,256)
        #第二层 上一层输出是这一层的输入
        self.fc2 = nn.Linear(256,64)
        #第三层 是最终的输出=== 分类数有关
        self.fc3 = nn.Linear(64,10)

    def forward(self,x):
        # x[512,1,28,28] 输入层结构:512张灰度图片,28*28
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        #x = F.relu(self.fc3(x))
        #一般来说,最后一层激活函数可加可不加
        x = self.fc3(x)
        return x

entrenamiento en red

#3 网络训练
#迭代的次数,对数据集迭代3次
for epoch in range(3):
    #每次迭代,对数据集每512张做训练
    for batch_idx, (x,y) in enumerate(train_loader):
        # x[512,1,28,28] 28*28===1*784 打平矩阵,维度转换
        x = x.view(x.size(0),28*28)#一维 1*784
        # 放入网络训练
        #out:[512,10]
        out = net(x)
        #label用onthot编码转化成向量
        y_onehot = one_hot(y)
        #计算loss 欧式距离
        loss = F.mse_loss(out,y_onehot)
        #梯度下降
        #梯度清零
        optimizer.zero_grad()
        #计算梯度
        loss.backward()
        #更新梯度 w' = w - lr * grad
        optimizer.step()
        #此时退出循环,得到了最好的结果【w1,w2,w3,b1,b2,b3】
        if batch_idx % 10 == 0:
            #每10次打印loss
            print(epoch,batch_idx,loss.item())

Prueba de verificación

#4 验证
total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)#[512,x]
    pred = out.argmax(dim =1)#dim维度
    #pred =? 相等的数量有几张 eq()相等记为1,不相等记为0
    correct =pred.eq(y).sum().float().item()
    total_correct+=correct

total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:",acc)

x,y =next(iter(test_loader))
out =net(x.view(x.size(0),28*28))
pred = out.argmax(dim =1)
plot_image(x,pred,'test')

Todo el código

Si desea escribirlo como un script, simplemente copie la siguiente parte de la función en el mismo archivo py. No es necesario crear un archivo py adicional. Para que el código sea más fácil de mantener y depurar, se recomienda separarlo. él.

import torch
from matplotlib import pyplot as plt#绘图库
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

La dirección para compartir archivos mnist.py debe obtenerla usted mismo

Enlace: https://pan.baidu.com/s/1psjbAH5wxtaAyQpRXArr6g?pwd=y88a Código de extracción: y88a Copie este contenido y abra la aplicación móvil Baidu Netdisk para una operación más conveniente.

Por favor corríjanme si hay algún error.

Supongo que te gusta

Origin blog.csdn.net/m0_64892604/article/details/128882879
Recomendado
Clasificación