PyTorch classic dataset: MNIST handwritten digit recognition

1. What is MNIST?

MNIST is the most basic dataset in computer vision, and for many people training on it is their first neural network model.

The MNIST dataset (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits built from data originally collected by the National Institute of Standards and Technology (NIST). It contains a training set of 60,000 samples and a test set of 10,000 samples.

In this article, each 28*28 grayscale image is flattened into a one-dimensional vector of length 784 before it is fed to the network, where each element is the grayscale value of one pixel. Each label is encoded as a one-hot vector of length 10: the element at the index of the true digit is 1 and all other elements are 0. The network's 10 outputs are compared against this vector, and the index of the largest output is taken as the predicted digit.
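As a minimal, torch-free sketch of the two encodings described above (the helper names flatten and one_hot_list are made up for this illustration and are not part of the article's code):

```python
def flatten(image):
    """Flatten a 2-D list of pixel rows into a single 1-D list, row by row."""
    return [pixel for row in image for pixel in row]

def one_hot_list(digit, depth=10):
    """Return a list of length `depth` with 1.0 at index `digit`, 0.0 elsewhere."""
    return [1.0 if i == digit else 0.0 for i in range(depth)]

# A dummy all-black 28x28 "image" stands in for a real MNIST sample
image = [[0.0] * 28 for _ in range(28)]
vector = flatten(image)
print(len(vector))      # 784
print(one_hot_list(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```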

2. Detailed code introduction

The goal of MNIST handwritten digit recognition is to train a model that can classify images of handwritten digits.

First understand the steps and process, and then start building the network structure and training the model.

Import the libraries to be used

utils is a separate file (utils.py) containing a few helper functions written for this article. Its full code is given at the end of the article.

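# Import the required libraries
import torch
# Neural network modules
from torch import nn
# Common neural network functions (activations, losses)
from torch.nn import functional as F
# Optimizers (gradient descent and its variants)
from torch import optim
# Computer vision datasets and transforms
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot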

Load dataset

# 1. Load the dataset
# Batch size; 512 matches the tensor shapes noted in the comments further down
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)  # shuffle=True randomizes the sample order

# Preview one batch of training data
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())

# Plot a few sample images together with their labels
plot_image(x, y, 'image_sample')
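The minimum and maximum printed by the preview follow from the Normalize step. ToTensor scales pixels to the range [0, 1], and Normalize then applies (x - mean) / std to each pixel. A small sketch of that arithmetic in plain Python (the function names here are illustrative only):

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to Normalize in the loaders

def normalize(pixel):
    """Per-pixel formula applied by Normalize: (x - mean) / std."""
    return (pixel - MEAN) / STD

def denormalize(value):
    """Inverse of normalize, the same arithmetic plot_image uses before display."""
    return value * STD + MEAN

lo = normalize(0.0)  # darkest possible pixel after ToTensor
hi = normalize(1.0)  # brightest possible pixel after ToTensor
print(round(lo, 4), round(hi, 4))  # -0.4242 2.8215
```

So a batch that contains both pure black and pure white pixels should report roughly these two values as x.min() and x.max().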

Create the network: define a Net class with three fully connected (linear) layers, with a ReLU activation after each of the first two layers


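# 2. Create the network
# Three linear layers with ReLU activations in between
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Three linear layers, each computing xw + b
        # Layer 1: the 28*28 image is flattened into a vector of 784 inputs.
        # Hidden sizes are usually powers of 2 that shrink layer by layer.
        # Linear(in_features, out_features)
        self.fc1 = nn.Linear(28*28, 256)
        # Layer 2: its input is the output of the previous layer
        self.fc2 = nn.Linear(256, 64)
        # Layer 3: the final output, one value per class (10 digits)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 784], a batch of b flattened grayscale images (b = 512 here)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # x = F.relu(self.fc3(x))
        # An activation on the last layer is optional; it is left out here
        x = self.fc3(x)
        return x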
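To get a feel for the size of this network, the number of trainable parameters can be worked out by hand from the layer sizes in Net. The sketch below does the arithmetic in plain Python (linear_params is a helper name invented for this example):

```python
# (inputs, outputs) of fc1, fc2 and fc3 as defined in Net
layer_sizes = [(28 * 28, 256), (256, 64), (64, 10)]

def linear_params(n_in, n_out):
    """A linear layer y = xW + b has n_in * n_out weights plus n_out biases."""
    return n_in * n_out + n_out

per_layer = [linear_params(n_in, n_out) for n_in, n_out in layer_sizes]
print(per_layer)       # [200960, 16448, 650]
print(sum(per_layer))  # 218058
```

Almost all of the parameters sit in the first layer, because it connects every one of the 784 input pixels to 256 hidden units.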

Train the network

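# 3. Train the network
net = Net()
# SGD optimizer. The original listing omits this line, so lr and momentum
# are assumed values, not necessarily the ones the author used.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# Record the loss of every step so it can be plotted afterwards
train_loss = []

# Number of epochs: iterate over the whole dataset 3 times
for epoch in range(3):
    # Each inner iteration trains on one batch of 512 images
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [512, 1, 28, 28] -> flatten each image to 784 values: [512, 784]
        x = x.view(x.size(0), 28*28)
        # Forward pass; out: [512, 10]
        out = net(x)
        # Convert the labels to one-hot vectors: [512, 10]
        y_onehot = one_hot(y)
        # Loss: mean squared error between the output and the one-hot label
        loss = F.mse_loss(out, y_onehot)
        # Clear the gradients left over from the previous step
        optimizer.zero_grad()
        # Backpropagate to compute the gradients
        loss.backward()
        # Update the parameters: w' = w - lr * grad (plus the momentum term)
        optimizer.step()
        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            # Print the loss every 10 batches
            print(epoch, batch_idx, loss.item())

# After the loops finish, net holds the trained parameters [w1, b1, w2, b2, w3, b3]
plot_curve(train_loss)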
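The update rule noted in the training loop, w' = w - lr * grad, can be seen in isolation on a toy problem. The following is only an illustrative sketch in plain Python with a single weight and a single made-up data point (x = 2, y = 6); it is not how PyTorch implements the optimizer internally.

```python
# Toy problem: find w such that w * x is close to y (the exact answer is w = 3)
x, y = 2.0, 6.0
w = 0.0    # initial weight
lr = 0.05  # learning rate, chosen arbitrarily for this example

for step in range(50):
    pred = w * x
    loss = (pred - y) ** 2     # squared error, as in mse_loss for one sample
    grad = 2 * (pred - y) * x  # d(loss)/dw, derived by hand
    w = w - lr * grad          # gradient descent update: w' = w - lr * grad

print(w)  # very close to 3.0
```

In the real training loop, loss.backward() computes the gradients for every weight automatically, and optimizer.step() applies the update to all of them at once.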

Evaluate on the test set

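# 4. Evaluate
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)  # out: [b, 10]
    # The predicted digit is the index of the largest output along dim 1
    pred = out.argmax(dim=1)
    # eq() gives 1 where the prediction equals the label and 0 elsewhere;
    # the sum is the number of correct predictions in this batch
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("acc:", acc)

# Show the predictions for a few test images
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')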
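The accuracy computation above boils down to two steps: take the index of the largest output for each sample, then count how many of those indices equal the labels. A tiny sketch with made-up numbers in plain Python (three classes instead of ten, to keep it short; argmax here is a local helper, not the tensor method):

```python
def argmax(values):
    """Index of the largest element in a list."""
    return max(range(len(values)), key=lambda i: values[i])

# Made-up network outputs for 4 samples and 3 classes
outputs = [
    [0.1, 0.7, 0.2],  # largest at index 1
    [0.8, 0.1, 0.1],  # largest at index 0
    [0.2, 0.3, 0.5],  # largest at index 2
    [0.6, 0.3, 0.1],  # largest at index 0
]
labels = [1, 0, 1, 0]

preds = [argmax(row) for row in outputs]
correct = sum(1 for p, t in zip(preds, labels) if p == t)
acc = correct / len(labels)
print(preds, acc)  # [1, 0, 2, 0] 0.75
```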

utils.py code

The functions below are the contents of utils.py. If you prefer a single script, copy them into the same .py file as the code above instead of creating a separate file, and remove the "from utils import" line. Keeping them in a separate file is recommended, because it makes the code easier to maintain and debug.

import torch
from matplotlib import pyplot as plt  # plotting library
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    # Column of class indices, shape [n, 1], as required by scatter_
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
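As a side note, PyTorch also ships a built-in one-hot function, torch.nn.functional.one_hot. The sketch below compares it with a scatter-based helper of the same shape as the one in utils.py; it assumes the labels are an integer tensor of class indices, as MNIST labels are.

```python
import torch
from torch.nn import functional as F

def one_hot(label, depth=10):
    # Same scatter-based approach as the utils.py helper
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

label = torch.tensor([3, 0, 9])
manual = one_hot(label)
# The built-in returns an integer tensor, so convert it to float for comparison
builtin = F.one_hot(label, num_classes=10).float()
print(torch.equal(manual, builtin))  # True
```

Either version works with the MSE loss used in this article, as long as the result is a float tensor.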

The mnist.py file can be downloaded from the following share link:

Link: https://pan.baidu.com/s/1psjbAH5wxtaAyQpRXArr6g?pwd=y88a Extraction code: y88a

Corrections are welcome if you find any errors.

Origin blog.csdn.net/m0_64892604/article/details/128882879