基于pytorch实现手写数字识别(附python代码)

/1加载图片:加载数据集,没有的话会自动下载,数据分布在0附近,并打散。

训练集:测试集=6k:1k。

utils.py文件:plot_image()绘制loss下降曲线; plot_curve()显示图片通过plot_image()可视化结果。minst_train.py文件:读取Minst数据集

/2 加载模型:三层线性模型,前两层用ReLU函数,batch_size=512,一张图片28*28,Normalize将数据均匀分布。

/3 训练:学习率0.01,momentum = 0.9,loss定义,梯度清零、计算、更新,每10次显示loss,可以看到loss下降:

/4 测试

计算正确率并显示梯度下降:

遇到的问题:pytorch中优化器获得的是空参数表

ValueError:optimizer got an empty parameter list

解决:初始函数定义未正确,两个下划线

def __init__(self):

        super(Net, self).__init__()

win10+anaconda3+python3.7,安装tensorflow、pytorch、opencv、CUDA10.2

mnist_train.py

# -*- coding: utf-8 -*-
"""
Created on Tue Jan 14 15:10:20 2020

@author: ZM
"""
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from utils import plot_image, plot_curve, one_hot

batch_size=512
#step1:load dataset
#加载数据集,没有的话会自动下载,数据分布在0附近,并打散
train_loader=torch.utils.data.DataLoader(
   torchvision.datasets.MNIST('mnist_data',train=True,download=True,
                               transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                               (0.1307,),(0.3081,))
                                       ])),
    batch_size=batch_size,shuffle=True)
                                       
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/',train=False,download=True,
                               transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                               (0.1307,),(0.3081,))
                                       ])),
    batch_size=batch_size,shuffle=False)
                                       
#显示:batch_size=512,一张图片28*28,Normalize将数据均匀
x, y = next(iter(train_loader))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x, y, 'image sample')

#建立模型
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        #wx+b
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64,10)
        
    def forward(self, x):
        #x:[b,1,28,28]
        #h1=relu(w1x+b1)
        x = F.relu(self.fc1(x))
        #h2=relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        #h3=h2w3+b3
        x = self.fc3(x)
        
        return x
#        return F.log_softmax(x, dim=1)
#训练    
net = Net()#初始化
#返回[w1,b1,w2,b2,w3,b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum = 0.9)
train_loss = []

for epoch in range(3):
    for batch_idx, (x,y) in enumerate(train_loader):
        
#        x[b,1,28,28] y:[512]
#        print(x.shape,y.shape)
#        break
#        x, y = Variable(x), Variable(y)
        #[b,1,28,28]=>[b,784]实际图片4维打平为二维
    
        x = x.view(x.size(0), 28*28)
        #[b,10]
        out = net(x)
        #[b,10]
        y_onehot = one_hot(y)
        #loss=mse(out,y_onehot)
        loss = F.mse_loss(out, y_onehot)
        
        optimizer.zero_grad()
        loss.backward()
        #w'=w-li*grad
        optimizer.step()
        
#测试
        train_loss.append(loss.item())
        if batch_idx % 10==0:
            print(epoch, batch_idx, loss.item())
plot_curve(train_loss)
#达到较好的[w1,b1,w2,b2,w3,b3]
            
total_correct=0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)  
    #out:[b,10] => pred:[b]     
    out = net(x)
     
    pred = out.argmax(dim = 1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
     
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

x,y = next(iter(test_loader))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim = 1)
plot_image(x, pred, 'test')

utils.py 

# -*- coding: utf-8 -*-
"""
Created on Tue Jan 14 16:37:46 2020

@author: ZM
"""

import torch
from matplotlib import pyplot as plt

def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
    
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
发布了40 篇原创文章 · 获赞 3 · 访问量 7569

猜你喜欢

转载自blog.csdn.net/OpenSceneGraph/article/details/104147092
今日推荐