Deep Learning: MNIST Handwritten Digit Recognition Using a Fully Connected Neural Network (FCN)

1 Introduction

This project builds a fully connected neural network (FCN) in PyTorch to recognize handwritten digits in the MNIST dataset, and walks through the whole process of handwritten digit recognition in principle, including backpropagation and gradient descent.

2 Introduction to Fully Connected Neural Networks

2.1 What is a fully connected neural network

A fully connected network (FCN) is a multi-layer neural network in which each neuron in layer N is connected to every neuron in layer N-1. The following figure shows a simple fully connected network:
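As a minimal sketch (the layer sizes here are illustrative, not taken from the project), a single fully connected layer in PyTorch is just a matrix multiplication plus a bias, so every output neuron is connected to every input neuron:

import torch
from torch import nn

# one fully connected layer: y = x @ W.T + b
layer = nn.Linear(4, 3)        # 4 inputs, 3 outputs
x = torch.randn(2, 4)          # a batch of 2 samples with 4 features each
y = layer(x)                   # shape: (2, 3)
print(y.shape)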

2.2 Loss function

In deep learning, the loss function measures the error between the value predicted by the model and the true value; it quantifies how well the model agrees with the data. The higher the loss, the worse the prediction; the lower the loss, the closer the prediction is to the true value. The loss function is computed for each individual observation (data point), while the function that averages the loss over all observations is called the cost function. Put simply, the loss function is defined for a single sample and the cost function for all samples (a small numeric sketch of this distinction follows the list below).

  • The smaller the loss, the better the model fits the data
  • It measures the difference between the actual output and the target
  • It provides the basis for updating the parameters (backpropagation)
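To make the loss/cost distinction concrete, here is a small numeric sketch (the values are made up for illustration): the squared error is computed per observation, and the cost is its average over the whole batch.

import torch

y_hat = torch.tensor([2.5, 0.0, 2.0])    # model predictions (illustrative values)
y     = torch.tensor([3.0, -0.5, 2.0])   # true target values

per_sample_loss = (y_hat - y) ** 2       # loss: one value per observation
cost = per_sample_loss.mean()            # cost: the average over all samples
print(per_sample_loss, cost)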

Common Loss Functions

(1) Mean Squared Error Loss (MSE)

Mean squared error (MSE) loss, also known as L2 loss, computes the mean of the squared differences between the model output y_hat and the target value y. It is commonly used in linear regression and can be understood as the least-squares criterion. MSE is the most commonly used loss function for regression tasks in machine learning and deep learning.
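A quick sketch of this definition with illustrative tensors (not from the project code): PyTorch's nn.MSELoss averages the squared differences between y_hat and y.

import torch
from torch import nn

y_hat = torch.tensor([1.0, 2.0, 3.0])
y     = torch.tensor([1.5, 2.0, 2.0])

mse = nn.MSELoss()
print(mse(y_hat, y))               # built-in MSE
print(((y_hat - y) ** 2).mean())   # the same value computed by hand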

(2) Mean Absolute Error Loss (MAE)

Mean absolute error (MAE), also known as L1 loss, is another loss function used in regression models. Like MSE, it measures the magnitude of the error without considering its direction (a signed version would be the mean bias error, MBE, which simply averages the residuals). Unlike MSE, however, MAE is not differentiable at zero, so computing its gradients requires subgradients or tools such as linear programming. MAE is also more robust to outliers because it does not square the errors. Its range is likewise 0 to ∞.
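The corresponding sketch for MAE (again with illustrative values): nn.L1Loss averages the absolute differences, so a single large outlier pulls on it less strongly than it would with MSE.

import torch
from torch import nn

y_hat = torch.tensor([1.0, 2.0, 10.0])   # the last prediction is an outlier
y     = torch.tensor([1.5, 2.0, 2.0])

mae = nn.L1Loss()
print(mae(y_hat, y))               # built-in MAE
print((y_hat - y).abs().mean())    # the same value computed by hand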

(3) Cross Entropy Loss

Cross entropy is an important concept in Shannon's information theory, used to measure the difference between two probability distributions. The performance of language models, for example, is usually reported as cross entropy or perplexity: it expresses how difficult the text is for the model to predict, or, from a compression point of view, how many bits are needed on average to encode each word. The cross entropy loss is the most common loss function in classification problems.
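A small sketch of the classification case (logits and labels are made up for illustration): PyTorch's nn.CrossEntropyLoss takes raw scores (logits) and integer class labels, and applies log-softmax internally.

import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5, 0.1],   # raw scores for 3 classes, batch of 2
                       [0.2, 0.3, 3.0]])
labels = torch.tensor([0, 2])             # correct class indices

ce = nn.CrossEntropyLoss()
print(ce(logits, labels))                 # mean of -log softmax(logits)[label]

# equivalent manual computation
log_probs = torch.log_softmax(logits, dim=1)
print(-log_probs[torch.arange(2), labels].mean())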

2.3 Backpropagation

The error back-propagation (BP) algorithm was a major breakthrough in the development of neural networks and is the basis of most deep learning training methods. It computes the gradient of the loss function with respect to every parameter in the network; an optimization method then uses these gradients to update the parameters and reduce the loss. BP originally referred only to the process of the loss gradient flowing backward through the network, but it is now often understood as the whole training procedure, consisting of two stages: error propagation and parameter updating.

During training, forward propagation and backpropagation alternate. Forward propagation computes the output from the training data and the current weight parameters; backpropagation computes the gradient of the loss with respect to each parameter via the chain rule and updates the parameters according to these gradients.
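Here is a tiny autograd sketch of one forward/backward/update step (toy data and a one-layer model chosen only for illustration; the full training loop is in Section 3.2): loss.backward() fills each parameter's .grad via the chain rule, and optimizer.step() moves the parameters against the gradient.

import torch
from torch import nn

model = nn.Linear(3, 1)                                   # a one-layer "network" for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

x = torch.randn(8, 3)              # toy inputs
y = torch.randn(8, 1)              # toy targets

y_hat = model(x)                   # forward propagation
loss = criterion(y_hat, y)         # loss between prediction and target

optimizer.zero_grad()              # clear old gradients
loss.backward()                    # backpropagation: chain rule fills .grad for each parameter
optimizer.step()                   # gradient descent update of the parameters

print(model.weight.grad)           # gradient of the loss w.r.t. the weights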

 

3 Using an FCN for MNIST handwritten digit recognition

3.1 Introduction to the MNIST dataset

The MNIST dataset is a large database of handwritten digits collected by the National Institute of Standards and Technology (NIST). It contains a training set of 60,000 examples and a test set of 10,000 examples, each a 28×28 grayscale image. Sample images are shown below:
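These numbers are easy to verify (a sketch assuming torchvision is installed and the data is downloaded to ./data/, the same path used in the code below):

from torchvision import datasets

train_set = datasets.MNIST(root="./data/", train=True, download=True)
test_set = datasets.MNIST(root="./data/", train=False, download=True)

print(len(train_set), len(test_set))   # 60000 10000
print(train_set.data[0].shape)         # torch.Size([28, 28])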

3.2 Code implementation of FCN recognition on the MNIST dataset

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

class MnistNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            # The original image is 28*28, flattened to 784: input layer 784, first hidden layer 256
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
            # Note: nn.CrossEntropyLoss applies log-softmax internally, so this extra Softmax is redundant here
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # flatten each 28*28*1 image into a 784-dimensional vector
        x = x.view(-1, 28*28*1)
        return self.layer(x)


batchsize = 32
lr = 0.01

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307, ), (0.3081,))])

data_train = datasets.MNIST(root="./data/", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data/", transform=transform, train=False)

train_loader = torch.utils.data.DataLoader(data_train, batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=batchsize, shuffle=False)

if __name__ == '__main__':
    model = MnistNet()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.5)

    for i in range(5):
        plt.subplot(1, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(data_train.data[i], cmap=plt.cm.binary)
    plt.show()

    lepoch = []
    llost = []
    lacc = []
    epochs = 30
    for epoch in range(epochs):
        lost = 0
        count = 0
        for num, (x, y) in enumerate(train_loader, 1):
            y_h = model(x)
            loss = criterion(y_h, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lost += loss.item()
            count += batchsize
        print('epoch:', epoch + 1, 'loss:', lost / count, end=' ')
        lepoch.append(epoch + 1)
        llost.append(lost / count)

        with torch.no_grad():
            acc = 0
            count = 0
            for num, (x, y) in enumerate(test_loader, 1):
                y_h = model(x)
                _, y_h = torch.max(y_h.data, dim=1)
                acc += (y_h == y).sum().item()
                count += x.size(0)
            test_acc = acc / count * 100

        lacc.append(test_acc)
        print('acc:', test_acc)

    plt.plot(lepoch, llost, label='loss')
    plt.plot(lepoch, lacc, label='acc')
    plt.legend()
    plt.show()

3.3 Result output

After 30 epochs, the accuracy on the test set reaches 97.3%.

epoch: 1 loss: 0.0697015597740809 acc: 56.120000000000005
epoch: 2 loss: 0.0542279725531737 acc: 81.2
epoch: 3 loss: 0.051337766939401626 acc: 83.53
epoch: 4 loss: 0.05083678769866626 acc: 84.49
epoch: 5 loss: 0.05052243163983027 acc: 85.09
epoch: 6 loss: 0.05029139596422513 acc: 85.65
epoch: 7 loss: 0.050102355525890985 acc: 86.14
epoch: 8 loss: 0.04994755889574687 acc: 86.02
epoch: 9 loss: 0.0498184863169988 acc: 86.71
epoch: 10 loss: 0.04970114469528198 acc: 86.81
epoch: 11 loss: 0.04792855019172033 acc: 94.86
epoch: 12 loss: 0.047099880089362466 acc: 95.64
epoch: 13 loss: 0.04690476657748222 acc: 96.04
epoch: 14 loss: 0.04677621142864227 acc: 96.32
epoch: 15 loss: 0.046683601369460426 acc: 96.52
epoch: 16 loss: 0.04659009942809741 acc: 96.69
epoch: 17 loss: 0.04652327968676885 acc: 96.72
epoch: 18 loss: 0.04646410925189654 acc: 96.81
epoch: 19 loss: 0.0464125766257445 acc: 96.75
epoch: 20 loss: 0.04636456128358841 acc: 97.07000000000001
epoch: 21 loss: 0.046326734560728076 acc: 96.85000000000001
epoch: 22 loss: 0.04628034559885661 acc: 96.91
epoch: 23 loss: 0.04625135076443354 acc: 97.0
epoch: 24 loss: 0.046217381453514096 acc: 97.14
epoch: 25 loss: 0.046193461724122364 acc: 97.03
epoch: 26 loss: 0.046168098962306975 acc: 97.16
epoch: 27 loss: 0.0461397964378198 acc: 97.27
epoch: 28 loss: 0.0461252645790577 acc: 97.22
epoch: 29 loss: 0.04609716224273046 acc: 97.19
epoch: 30 loss: 0.04608173056840897 acc: 97.3

The accuracy change curve is as follows:
