7 Steps to Understanding MNIST Handwritten Digit Recognition

Hello everyone, I am Dong Dongcan.

There are many introductory projects for image recognition, and MNIST handwritten digit recognition is easily the most popular.

The project has the advantages of a small dataset, a simple neural network, and a simple task, yet it contains everything a typical CNN workflow needs. As the saying goes, the sparrow may be small, but it has all the vital organs.

This makes it very suitable for beginners.

This article takes you through every detail of the project in the form of a code walk-through.

The code download link is attached at the end of the article. You can train a neural network from scratch without a GPU.

What is Handwritten Digit Recognition

In short, the task is to build a convolutional neural network (CNN) that can recognize handwritten digits.

If I write a 6 on paper with a pen, the neural network recognizes it as a 6; if I write an 8, it recognizes it as an 8. It's that simple.

The task is simple because its labels cover only 10 classes, the digits 0-9, far fewer than the 1,000 ImageNet classes handled by ResNet and other networks.

Although simple, many principles lie behind it, and the typical elements of CNN training and algorithms are all present.

Hand in hand with this project goes the famous MNIST (Modified National Institute of Standards and Technology) dataset.

The dataset contains 60,000 training images and 10,000 test images, all handwritten digits in a wide variety of styles.

7-step code walk-through

With the project background out of the way, I will introduce the neural network piece by piece by reading through the code.

Step 1: Import the necessary libraries

# Import the NumPy math toolbox
import numpy as np
# Import the Pandas data-processing toolbox
import pandas as pd
# Import the MNIST dataset from Keras
from keras.datasets import mnist

Keras is an open-source neural network library that ships with many classic networks and datasets, including the MNIST dataset used here.

Step 2: Load the dataset

(X_train_image, y_train_label), (X_test_image, y_test_label) = mnist.load_data()

This line uses the mnist module that ships with Keras to load the dataset (load_data) and assign it to four variables.

Among them, X_train_image holds the images used for training, and y_train_label holds the corresponding labels: if the digit in an image is 1, its label is 1.

X_test_image and y_test_label are the images and labels of the test set. After the neural network is trained, it can be evaluated on this held-out data.
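To get a quick feel for the data, you can print the shapes of the four variables and display one sample. A minimal check (assuming matplotlib is installed; it is used again in Step 7):

# Check the shapes of the four variables
print(X_train_image.shape)  # (60000, 28, 28)
print(y_train_label.shape)  # (60000,)
print(X_test_image.shape)   # (10000, 28, 28)
print(y_test_label.shape)   # (10000,)

# Display the first training image together with its label
import matplotlib.pyplot as plt
plt.imshow(X_train_image[0], cmap='Greys')
plt.title('label: %d' % y_train_label[0])
plt.show()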

Step 3: Data Preprocessing

One part of preprocessing is reshaping the data so that its shape matches what the model expects.

# Import the category-conversion utility from keras.utils
from tensorflow.keras.utils import to_categorical
# Add a channel dimension so the images match the model's input;
# the reshaped training images have shape [60000, 28, 28, 1]
X_train = X_train_image.reshape(60000, 28, 28, 1)
X_test = X_test_image.reshape(10000, 28, 28, 1)
# Convert the labels to one-hot encoding
y_train = to_categorical(y_train_label, 10)
y_test = to_categorical(y_test_label, 10)

The dataset contains 60,000 training images and 10,000 test images in total; each image is 28 pixels high and 28 pixels wide, with a single channel.

So the training set is reshaped to NHWC = [60000, 28, 28, 1], and the test set is handled the same way.

The role of to_categorical is to convert the sample labels to one-hot encoding, which makes it easier to compute a probability or score per category.

one-hot

One-hot encoding is used because the 10 output labels (0-9) should all have equal standing: there is no sense in which label 2 is "greater than" label 1.

But if we used the raw label values (0-9) directly when computing the final result, exactly such an ordering would creep in, with label 2 appearing greater than label 1.

Therefore, in most cases the labels are converted to one-hot encoding, so that no label is larger or smaller than another.

In this example, the one-hot encoding of the digits 0-9 is:

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

Each row vector represents one label.

Suppose [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.] represents 0 and [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.] represents 1. The two vectors are orthogonal and independent, so there is no question of one being bigger than the other.
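If you want to see the conversion in action, a minimal sketch: call to_categorical on a few raw labels and inspect the result.

from tensorflow.keras.utils import to_categorical

# Convert three raw labels into one-hot vectors over 10 classes
print(to_categorical([0, 1, 7], 10))
# [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]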

Step 4: Create the neural network

# Import the model class from Keras
from keras import models
# Import the layers the network needs from keras.layers
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
# Build the most basic sequential model: one layer stacked after another
model = models.Sequential()
# First layer: a convolution with a (3, 3) kernel, 32 output channels,
# and ReLU as the activation function
model.add(Conv2D(32, (3, 3), activation='relu',
                 input_shape=(28, 28, 1)))
# Second layer: max pooling with a (2, 2) window.
# Max pooling keeps only the largest pixel within each (2, 2) window,
# reducing the amount of data and computation.
model.add(MaxPooling2D(pool_size=(2, 2)))
# Another (3, 3) convolution; the output channels grow to 64,
# i.e. 64 features are extracted. Again with ReLU activation.
model.add(Conv2D(64, (3, 3), activation='relu'))
# The channel count grew above, increasing the computation,
# so add another max pooling layer to bring it back down
model.add(MaxPooling2D(pool_size=(2, 2)))
# Dropout randomly zeroes the outputs of some neurons during training
# to prevent overfitting; here 25% of the neurons are dropped
model.add(Dropout(0.25))
# Flatten the result into a 1-D vector
model.add(Flatten())
# Add a fully connected layer for further feature fusion
model.add(Dense(128, activation='relu'))
# Another dropout layer, dropping 50% of the neurons to prevent
# overfitting; with half the neurons inactive, training also speeds up
model.add(Dropout(0.5))
# Finally, a fully connected layer with softmax activation outputs
# 10 classes, corresponding to the 10 digits 0-9
model.add(Dense(10, activation='softmax'))

Each line above carries a comment explaining what it does. These few lines are the entire handwritten digit recognition network.
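As a sanity check on how shapes flow through these layers: each unpadded (3, 3) convolution shrinks the sides by 2 (28 → 26, then 13 → 11), and each (2, 2) max pooling halves them, rounding down (26 → 13, then 11 → 5), so Flatten produces a vector of 5 × 5 × 64 = 1600 elements. You can confirm this by printing the model summary:

# Print each layer's output shape and parameter count
model.summary()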

Step 5: Training

# Compile the model built above
# Use rmsprop as the optimizer
# Use cross-entropy as the loss function
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Start training
model.fit(X_train, y_train,       # training features and labels
          validation_split=0.3,   # split part of the training set off as a validation set
          epochs=5,               # train for 5 epochs
          batch_size=128)         # use a batch size of 128

Epoch 5/5
329/329 [================================] - 15s 46ms/step - loss: 0.1054 - accuracy: 0.9718 - val_loss: 0.0681 - val_accuracy: 0.9826

The training log is shown above: the final validation accuracy reached 98.26%, which is quite high.
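If you want to watch the accuracy evolve over the epochs, model.fit returns a History object whose history dictionary holds the per-epoch metrics. A minimal sketch (note that recent Keras versions use the keys 'accuracy' and 'val_accuracy'; older versions use 'acc' and 'val_acc'):

# Keep the History object returned by model.fit
history = model.fit(X_train, y_train,
                    validation_split=0.3,
                    epochs=5,
                    batch_size=128)

# Plot training vs. validation accuracy per epoch
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()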

Step 6: Evaluate on the test set

# Evaluate the model on the test set
score = model.evaluate(X_test, y_test)
print('Test set prediction accuracy:', score[1])  # print the accuracy on the test set

313/313 [==============================] - 1s 4ms/step - loss: 0.0662 - accuracy: 0.9815
Test set prediction accuracy: 0.9815000295639038

As you can see, the model also reaches about 98% accuracy on the test set.

Step 7: Verify a single image

# Predict the first sample of the test set
pred = model.predict(X_test[0].reshape(1, 28, 28, 1))
# Convert the one-hot output back to a digit
print(pred[0], "converted to a digit:", pred.argmax())
# Import the plotting toolbox
import matplotlib.pyplot as plt
# Display the image
plt.imshow(X_test[0].reshape(28, 28), cmap='Greys')

Here we take the first image in the test set as an example.

1/1 [==============================] - 0s 17ms/step
[4.2905590e-15 2.6790809e-11 2.8249305e-09 2.3393848e-11 7.1304548e-14
1.8217797e-18 5.7493907e-19 1.0000000e+00 8.0317367e-15 4.6352322e-10]

converted to a digit: 7

The predicted digit is 7, and displaying the image confirms it is indeed a 7. The trained model really has learned to recognize handwritten digits.
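To go one step further and test a digit you wrote yourself, you could photograph it, shrink it to 28 × 28 pixels, and feed it to the model. A rough sketch, not part of the original article: it assumes Pillow is installed and uses a hypothetical file my_digit.png containing a dark digit on a white background.

import numpy as np
from PIL import Image

# Load the (hypothetical) photo, convert to grayscale, resize to 28 x 28
img = Image.open('my_digit.png').convert('L').resize((28, 28))
# MNIST digits are light strokes on a dark background, so invert
# a dark-on-light photo before feeding it to the model
arr = 255 - np.array(img)
# Reshape to (batch, height, width, channels) and predict
pred = model.predict(arr.reshape(1, 28, 28, 1))
print('predicted digit:', pred.argmax())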

Summary

The handwritten digit recognition project is fairly simple: it has only two convolutional layers, and the overall computational load is small. On today's hardware, even a laptop can complete the training and validation of this network.

If you are interested, follow the official account "Dong Dongcan is a siege lion" and reply [mnist] in the background to get the source code and try it yourself.
