PyTorch 实现 MNIST 手写数字识别

代码是跟着龙龙老师的视频敲的，其中有几行参考了 https://www.cnblogs.com/liualexsone/p/11355217.html 这位朋友的教程，写得非常好，谢谢！

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from pytorch__lesson.pytorch_mnist.utils import plot_curve,plot_image,one_hot
import matplotlib.pyplot as plt

# Step 1: load the MNIST train/test splits (download=True fetches them into
# the given root directory if they are not already there).
batch_size = 512

# Shared preprocessing for both splits: PIL image -> tensor, then
# Normalize(mean=(0.1307,), std=(0.3081,)). The same constants are used to
# undo the normalization when plotting.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_set = torchvision.datasets.MNIST(
    'mnist_data', train=True, download=True, transform=mnist_transform)
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True)

test_set = torchvision.datasets.MNIST(
    'mnist_data/', train=False, download=True, transform=mnist_transform)
# Test batches are served in a fixed order.
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False)

# To eyeball one training batch, uncomment:
# x, y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x, y, 'image sample')

# Build the network.
class Net(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 64 -> 10.

    Expects inputs whose last dimension is 784 (a flattened 28x28 image) and
    returns 10 raw, unnormalized scores per sample. ReLU follows the first
    two linear layers; the final layer has no activation.
    """

    def __init__(self):
        super().__init__()
        # Names fc1/fc2/fc3/relu determine the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)


net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.99)
# The loss module must be constructed (declared) before it is called below.
# MSE against one-hot targets follows the lesson's setup.
losses_func = nn.MSELoss()

for epoch in range(5):
    # Reset per epoch so the printed accuracy describes the current epoch
    # instead of a running average over every epoch seen so far.
    correct = 0
    total = 0
    for batch_index, (x, y) in enumerate(train_loader):
        # Flatten each image to a 784-vector for the fully connected net.
        x = x.view(x.size(0), 28 * 28)
        output = net(x)
        onehot_y = one_hot(y)
        loss = losses_func(output, onehot_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_index % 10 == 0:
            # Training accuracy, sampled every 10th batch, measured on the
            # predictions made before this step's parameter update.
            predicted = output.argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            print("epoch {}: train accuracy over {} sampled images: {}%".format(
                epoch, total, 100 * correct / total))

# Evaluate on the held-out test split (test_loader was otherwise unused).
net.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    for x, y in test_loader:
        output = net(x.view(x.size(0), 28 * 28))
        test_correct += (output.argmax(dim=1) == y).sum().item()
        test_total += y.size(0)
print("test accuracy over {} images: {}%".format(
    test_total, 100 * test_correct / test_total))

下面是 utils.py 的代码：

import torch
from matplotlib import pyplot as plt

def one_hot(label, depth=10):
    """Return a one-hot encoding of integer class labels.

    NOTE: this module defines ``one_hot`` a second time further down; when
    the module is imported, that later definition replaces this one.

    Args:
        label: class indices in ``[0, depth)``. Any tensor or sequence is
            accepted and flattened, so a 1-D tensor, a list, a scalar, or an
            (N, 1) tensor all work.
        depth: number of classes (width of the encoding).

    Returns:
        A float32 tensor of shape (N, depth) with a single 1 per row, on the
        same device as ``label``.
    """
    # scatter_ needs an int64 index tensor; as_tensor converts once without
    # the legacy torch.LongTensor(...) constructor.
    idx = torch.as_tensor(label, dtype=torch.long).reshape(-1, 1)
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    # Row i gets a 1 in column idx[i].
    out.scatter_(dim=1, index=idx, value=1)
    return out

def plot_curve(data):
    """Plot ``data`` against its index (step 0..len-1) and show the figure."""
    steps = range(len(data))
    # Start a fresh figure so the curve is not drawn over a previous plot.
    plt.figure()
    plt.plot(steps, data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()



def plot_image(img, label, name):
    """Show the first six images of a batch in a 2x3 grid, titled by label.

    ``img[i][0]`` is drawn for each of the six samples, and ``label[i].item()``
    supplies the title, so both need at least six entries. Pixel values are
    un-normalized with std 0.3081 and mean 0.1307 (the same constants as the
    training script's Normalize) before display.
    """
    std, mean = 0.3081, 0.1307
    plt.figure()
    for slot in range(1, 7):
        idx = slot - 1
        plt.subplot(2, 3, slot)
        plt.tight_layout()
        plt.imshow(img[idx][0] * std + mean, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[idx].item()))
        # Hide axis ticks; they carry no information for images.
        plt.xticks([])
        plt.yticks([])
    plt.show()

# Build one-hot encodings.
def one_hot(label, depth=10):
    """Return a one-hot encoding of integer class labels.

    Args:
        label: class indices in ``[0, depth)``. A 1-D tensor of length N (the
            batch dimension) is the usual input. Any tensor or sequence is
            accepted and flattened, so a list, a scalar, or an (N, 1) tensor
            also works.
        depth: number of classes (width of the encoding).

    Returns:
        A float32 tensor of shape (N, depth) with a single 1 per row, on the
        same device as ``label``.
    """
    # scatter_ needs an int64 index tensor. as_tensor replaces the legacy
    # torch.LongTensor(...) constructor, which is CPU-only and does not cast
    # tensors of other dtypes.
    idx = torch.as_tensor(label, dtype=torch.long).reshape(-1, 1)
    # Allocate on the label's device so scatter_ works for CUDA labels too.
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    # Row i gets a 1 in column idx[i].
    out.scatter_(dim=1, index=idx, value=1)
    return out

独热码（one-hot）的实现方法已在上一篇博文中介绍过。

猜你喜欢

转载自www.cnblogs.com/daremosiranaihana/p/12541788.html