Recognizing MNIST Handwritten Digits with an LSTM (Long Short-Term Memory) Network in PyTorch

This post shows how to recognize handwritten digits with an LSTM in the PyTorch framework.

Before introducing the LSTM (long short-term memory) network, let me first introduce the RNN (recurrent neural network).

An RNN is a type of neural network for processing sequence data, such as speech or a piece of text. It was designed to give the network a form of memory: each module passes information on to the next one. Its network structure is as follows:


Given an input sequence (X1, X2, X3, X4, ...), the network produces predictions (Y1, Y2, Y3, Y4, ...). For a text sentiment classification problem, we discard all of the Y values except the last one; that last Y is the predicted sentiment class of the whole text.
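To make the sequence-to-one idea concrete, here is a minimal sketch that unrolls a plain RNN by hand with `torch.nn.RNNCell`: the hidden state is carried from one step to the next, an output Y is computed at every step, and only the last one is kept as the classification result. The sizes (4 time steps, 8 input features, 16 hidden units, 2 classes) and the names `cell`, `head`, and `seq` are arbitrary choices for illustration, not part of the original post.

```python
import torch
from torch import nn

torch.manual_seed(0)

SEQ_LEN, INPUT_SIZE, HIDDEN_SIZE, NUM_CLASSES = 4, 8, 16, 2

cell = nn.RNNCell(INPUT_SIZE, HIDDEN_SIZE)   # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
head = nn.Linear(HIDDEN_SIZE, NUM_CLASSES)   # turns a hidden state into class scores (one Y per step)

seq = torch.randn(1, SEQ_LEN, INPUT_SIZE)    # one sequence X1..X4, shape (batch, time, features)
h = torch.zeros(1, HIDDEN_SIZE)              # the memory starts out empty

outputs = []
for t in range(SEQ_LEN):
    h = cell(seq[:, t, :], h)                # information flows from step t-1 to step t through h
    outputs.append(head(h))                  # Y_t for this step

# For sentiment-style classification, drop Y1..Y3 and keep only the last Y
prediction = outputs[-1].argmax(dim=1)
print(len(outputs), prediction)
```

In practice `nn.RNN` or `nn.LSTM` runs this loop internally; writing it out only shows where the "memory" lives.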

When the parameters are updated by gradient descent, an RNN suffers from vanishing or exploding gradients. Backpropagating through time multiplies the error by the recurrent weight W once per time step. If W is less than 1, the error shrinks toward 0 by the time it reaches the first step, and the gradient vanishes; if W is greater than 1, the error grows without bound by the time it reaches the first step, and the gradient explodes.
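The effect can be reproduced with a toy recurrence. The sketch below assumes the simplest possible "RNN": a single scalar state multiplied by the same weight w at each of 28 steps (the number of rows in an MNIST image), ignoring the activation function that a real RNN also applies. Autograd then measures how strongly the first state influences the last one. The function name `grad_through_time` is hypothetical.

```python
import torch

def grad_through_time(w: float, steps: int) -> float:
    """Return d h_T / d h_0 for the scalar recurrence h_t = w * h_{t-1}."""
    h0 = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = w * h          # the same weight is applied at every time step
    h.backward()           # backpropagate from the last step to the first
    return h0.grad.item()  # mathematically equal to w ** steps

small = grad_through_time(0.5, 28)   # w < 1: the gradient vanishes
large = grad_through_time(1.5, 28)   # w > 1: the gradient explodes
print(f"w=0.5: {small:.3e}   w=1.5: {large:.3e}")
```

The gradient is simply w multiplied by itself once per step, which is why it either collapses toward 0 or blows up as the sequence gets longer.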

To mitigate this problem, we use the LSTM (long short-term memory) network, which is designed to slow down the decay of memory over long sequences.

The network structure of LSTM is as follows:


Like an RNN, it passes the output of the previous time step on to the next one, but each unit also contains gating functions that judge whether information is useful, so the network can choose to remember useful information and forget useless information. This is implemented with three gates: the input gate, the forget gate, and the output gate.
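The gate mechanism fits in a few lines. This sketch implements one LSTM time step by hand using the weights of a `torch.nn.LSTMCell` and checks the result against the built-in cell. It relies on PyTorch's documented weight layout, in which `weight_ih` and `weight_hh` stack the input, forget, cell-candidate, and output blocks in that order. The helper name `lstm_step` is hypothetical. Note that, besides the hidden state h, an LSTM carries a separate cell state c that acts as its long-term memory.

```python
import torch
from torch import nn

torch.manual_seed(0)

INPUT_SIZE, HIDDEN_SIZE = 28, 64
cell = nn.LSTMCell(INPUT_SIZE, HIDDEN_SIZE)

def lstm_step(x, h, c):
    """One LSTM time step written out gate by gate, reusing the weights of `cell`."""
    # PyTorch stacks the four gate blocks as: input, forget, cell candidate, output
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i = torch.sigmoid(i)             # input gate: how much new information to write
    f = torch.sigmoid(f)             # forget gate: how much of the old cell state to keep
    g = torch.tanh(g)                # candidate values to write into the cell state
    o = torch.sigmoid(o)             # output gate: how much of the cell state to expose
    c_next = f * c + i * g           # update the long-term memory
    h_next = o * torch.tanh(c_next)  # new hidden state, i.e. the output at this step
    return h_next, c_next

x = torch.randn(3, INPUT_SIZE)       # a batch of 3 inputs (e.g. one image row each)
h = torch.randn(3, HIDDEN_SIZE)      # previous hidden state
c = torch.randn(3, HIDDEN_SIZE)      # previous cell state

with torch.no_grad():
    h_manual, c_manual = lstm_step(x, h, c)
    h_ref, c_ref = cell(x, (h, c))   # PyTorch's built-in implementation
print(torch.allclose(h_manual, h_ref, atol=1e-6), torch.allclose(c_manual, c_ref, atol=1e-6))
```

Because the cell state is updated additively (`f * c + i * g`), a forget gate close to 1 lets information and gradients pass through many steps without being repeatedly squashed.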

To recognize MNIST handwritten digits, we can feed each 28×28-pixel image into the network one row at a time. Each row corresponds to one time step, so there are 28 inputs X and 1 output Y; the 28 pixel values of a row play the role of one element of the data sequence.
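The row-by-row reshaping can be checked on random data before downloading MNIST. This sketch uses a fake batch of 64 images in place of real data and shows the tensor shapes at each stage. It also shows why a classifier can read its answer from `r_out[:, -1, :]`: with `batch_first=True` and a single unidirectional layer, the output at the last time step is the same as the final hidden state `h_n`.

```python
import torch
from torch import nn

torch.manual_seed(0)

images = torch.rand(64, 1, 28, 28)  # stand-in for one MNIST batch: (batch, channel, height, width)
seq = images.view(-1, 28, 28)       # (batch, time_step = 28 rows, input_size = 28 pixels per row)

lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
r_out, (h_n, h_c) = lstm(seq)       # initial hidden and cell states default to zeros

print(seq.shape)    # torch.Size([64, 28, 28])
print(r_out.shape)  # output at every time step: torch.Size([64, 28, 64])
print(h_n.shape)    # final hidden state per layer: torch.Size([1, 64, 64])

# For a one-layer, one-directional LSTM, the last time step of r_out is the final hidden state
print(torch.allclose(r_out[:, -1, :], h_n[-1]))
```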

The following code trains the network, reports the accuracy on the test set during training, and finally prints the predictions for the first 10 test images.

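import torch
from torch import nn
import torchvision
import torchvision.datasets as dsets
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible results

# Hyperparameters
EPOCH = 1               # number of passes over the training set
BATCH_SIZE = 64
TIME_STEP = 28          # time steps = image height (one row per step)
INPUT_SIZE = 28         # inputs per step = image width (pixels per row)
LR = 0.01               # learning rate
DOWNLOAD_MNIST = False  # set to True on the first run to download MNIST into ./mnist

train_data = dsets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),  # PIL image -> FloatTensor (1, 28, 28) scaled to [0, 1]
    download=DOWNLOAD_MNIST,
)

test_data = dsets.MNIST(root='./mnist', train=False, download=DOWNLOAD_MNIST)

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# The raw test images are uint8 in [0, 255]; scale them to [0, 1] like ToTensor() does
test_x = test_data.data.type(torch.FloatTensor) / 255.   # shape (10000, 28, 28)
test_y = test_data.targets                               # shape (10000,)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,  # pixels per image row
            hidden_size=64,         # number of hidden units
            num_layers=1,           # a single LSTM layer
            batch_first=True,       # tensors are (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)  # 10 output classes, one per digit

    def forward(self, x):
        # x: (batch, time_step, input_size)
        # r_out: (batch, time_step, hidden_size)
        # h_n, h_c: final hidden and cell states, each (num_layers, batch, hidden_size)
        # Passing None starts from all-zero hidden and cell states
        r_out, (h_n, h_c) = self.rnn(x, None)

        # Classify from the output at the last time step
        out = self.out(r_out[:, -1, :])
        return out

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()  # takes raw scores (logits) and integer labels

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        b_x = b_x.view(-1, TIME_STEP, INPUT_SIZE)  # (batch, 1, 28, 28) -> (batch, 28, 28)

        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # backpropagation through time
        optimizer.step()        # update the weights

        if step % 50 == 0:
            with torch.no_grad():   # no gradients needed for evaluation
                test_output = rnn(test_x.view(-1, TIME_STEP, INPUT_SIZE))
            pred_y = torch.max(test_output, 1)[1]
            accuracy = (pred_y == test_y).sum().item() / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

# Compare predictions with the true labels for the first 10 test images
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')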
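The loop above evaluates all 10,000 test images in a single forward pass, which is simple but holds the activations for the whole test set in memory at once. A possible alternative, sketched here under that assumption, is a hypothetical `evaluate` helper that walks over a `DataLoader` in mini-batches under `torch.no_grad()`. To keep the example self-contained, it uses random tensors and a small model `TinyLSTM` with the same shape as the `RNN` class instead of the real MNIST data, so the printed accuracy is only a smoke test, not a real result.

```python
import torch
from torch import nn
import torch.utils.data as Data

def evaluate(model, loader):
    """Return the classification accuracy of `model` over all batches in `loader`."""
    model.eval()                          # switch off training-only behaviour (e.g. dropout)
    correct, total = 0, 0
    with torch.no_grad():                 # no gradient bookkeeping during evaluation
        for x, y in loader:
            logits = model(x.view(-1, 28, 28))  # (batch, 1, 28, 28) -> (batch, 28, 28)
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    model.train()                         # assumes the model was in training mode before
    return correct / total

# Hypothetical stand-in model with the same layout as the RNN class in this post
class TinyLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=28, hidden_size=64, batch_first=True)
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        r_out, _ = self.rnn(x)
        return self.out(r_out[:, -1, :])

# Random tensors standing in for MNIST (assumption: no download needed for the demo)
fake_x = torch.rand(200, 1, 28, 28)
fake_y = torch.randint(0, 10, (200,))
loader = Data.DataLoader(Data.TensorDataset(fake_x, fake_y), batch_size=64)
acc = evaluate(TinyLSTM(), loader)
print(f"accuracy on random data: {acc:.2f}")
```

With the real data, one could build `test_loader = Data.DataLoader(dsets.MNIST(root='./mnist', train=False, transform=torchvision.transforms.ToTensor()), batch_size=1000)` and call `evaluate(rnn, test_loader)`.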

Output:

