[Machine Learning] The classic example: handwritten digit recognition (complete workflow: DNN/CNN structure design, saving model parameters, resuming training from checkpoints, acc/loss visualization)

Environment: Python 3.7 + TensorFlow
For the complete code, trained model and parameters, and detailed documentation, see: handwritten digit recognition complete code + detailed documentation + model parameters

1 Overview

1.1 Tasks

The handwritten digit recognition task converts images of handwritten digits into the corresponding digit characters; a neural network model can be used for feature extraction and pattern recognition on the handwritten digit images.

1.2 Dataset

MNIST is a classic handwritten digit dataset and one of the most commonly used datasets for handwritten digit recognition. It contains 60,000 training images and 10,000 test images; each image is a 28x28 grayscale image, i.e., the number of channels is 1. Its importance lies in letting machine learning algorithms learn the characteristics of handwritten digits and perform handwritten digit recognition.

1.3 Solutions

The TensorFlow framework, which provides an API for the MNIST dataset, is used together with tf.keras to build a convolutional neural network, train the model, save it, and visualize the training results, realizing feature extraction and pattern recognition on handwritten digit images.

2 Solution

import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import os

2.1 Load and view training set/test set

tf.keras provides an API for the MNIST dataset, so it can be loaded directly:

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Visualize the first element of the training set x_train; the image is a grayscale image of size 28x28x1.

plt.imshow(x_train[0], cmap='gray')
plt.show()
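To confirm that the image matches its label, the corresponding entry of y_train can also be printed (a small optional check, not part of the original code):

# Optional check: print the label corresponding to the image shown above
print("label of x_train[0]:", y_train[0])  # for the standard Keras MNIST split this prints 5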

View the shapes of the training and test sets:

# View the shapes of the training set x and y
print("x_train.shape:\n", x_train.shape)
print("y_train.shape:\n", y_train.shape)

# View the shapes of the test set x and y
print("x_test.shape:\n", x_test.shape)
print("y_test.shape:\n", y_test.shape)


Since the images are single-channel, x_train has shape [60000, 28, 28], but TensorFlow requires the input to a convolutional network to be 4-dimensional (important), so the data needs to be reshaped as follows and then normalized.

# When doing convolution in TensorFlow, the data must be reshaped into a 4-D format
# The 4 dimensions are: number of samples, image height, image width, number of channels
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# Normalize the data
x_train, x_test = x_train / 255.0, x_test / 255.0
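A quick sanity check (optional, not in the original script) confirms the new 4-D shape and the normalized value range:

# Optional sanity check after reshaping and normalization
print(x_train.shape)                 # (60000, 28, 28, 1)
print(x_train.min(), x_train.max())  # 0.0 1.0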

2.2 CNN network structure design

If a plain DNN is used instead, the first two (convolutional) stages of the structure below can be removed; a DNN-only variant is sketched after the CNN code below.

Conv1 convolution kernel: 5x5x1x16
Conv2 convolution kernel: 5x5x16x32
Flatten
Dense1 number of neurons: 128
Dense2 number of neurons: 10
#-------------------------- 2. CNN network structure design --------------------------#
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(
        # input_shape=(28, 28, 1),  # may be omitted
        kernel_size=5,
        filters=16,         # kernel size: 5*5*1, number of kernels: 16
        strides=1,          # stride 1
        padding="same",     # padding: "same" = zero padding; "valid" (default) = no padding
        activation="relu"   # activation function: relu/sigmoid/...
        ),
    # output: 28*28*16

    tf.keras.layers.MaxPool2D(2, 2),  # pooling
    # output: 14*14*16

    tf.keras.layers.Conv2D(
        kernel_size=5,
        filters=32,  # kernel size: 5*5*16, number of kernels: 32
        strides=1,
        padding="same",
        activation="relu"
    ),
    # output: 14*14*32

    tf.keras.layers.MaxPool2D(2, 2),
    # output: 7*7*32

    tf.keras.layers.Flatten(),  # flatten
    tf.keras.layers.Dense(128, activation="relu"),    # fully connected layer 1: usually 128 or 64 neurons, relu activation
    tf.keras.layers.Dense(10, activation="softmax"),  # fully connected layer 2: number of neurons = output dimension, softmax for multi-class classification
])
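For reference, a DNN-only variant, as mentioned above, simply drops the convolution/pooling layers and flattens the image directly. This is an illustrative sketch, not part of the original code:

# DNN-only variant: flatten the 28x28x1 image and use fully connected layers only
dnn_model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),  # flatten the image to a 784-dim vector
    tf.keras.layers.Dense(128, activation="relu"),     # fully connected layer 1
    tf.keras.layers.Dense(10, activation="softmax"),   # fully connected layer 2: 10 classes
])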


2.3 Setting the optimizer and loss function

Optimizer choice: Adam.

Loss function: cross-entropy loss, used for classification.

#-------------------------- 3. Set the optimizer and loss function --------------------------#
model.compile(optimizer='adam',  # Adam optimizer
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  # loss function: cross-entropy
              metrics=['sparse_categorical_accuracy'])  # accuracy metric: 'accuracy' when labels and predictions are both numeric; 'categorical_accuracy' when both are one-hot; 'sparse_categorical_accuracy' when labels are numeric and predictions are one-hot probabilities

2.4 Saving/loading the model and resuming training from checkpoints

The model parameters are saved once after each training epoch.

The tf.keras.callbacks.ModelCheckpoint() callback saves the model's weights and biases during training, so that if training is interrupted, the previous model state can be restored and training can continue, i.e., resumable (checkpoint) training.

Parameter meanings: filepath specifies the path where the checkpoint is saved, save_weights_only specifies whether to save only the model weights, and save_best_only specifies whether to keep only the best result. By default, this callback saves after every epoch.

#-------------------------- 4. Save/load the model, resume training from checkpoints --------------------------#
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):  # the .index file exists if weights have been saved before
    print('------------------------load the model----------------------------')
    model.load_weights(checkpoint_save_path)  # load the saved weights

cp_callback = tf.keras.callbacks.ModelCheckpoint(  # callback that saves the model
    filepath=checkpoint_save_path,
    save_weights_only=True,  # save only the weights
    save_best_only=True      # keep only the best result
)

#-------------------------- 5. Train the model --------------------------#
history = model.fit(x_train, y_train,          # training data and labels
          batch_size=32, epochs=5,             # batch size, number of epochs
          validation_data=(x_test, y_test),    # validation data
          validation_freq=1,                   # validate every N epochs
          callbacks=[cp_callback]              # checkpoint callback
          )

#-------------------------- 6. Print the model structure --------------------------#
model.summary()
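The parameter counts reported by model.summary() follow directly from the architecture above and can be verified by hand (or with model.count_params()):

# Parameter counts implied by the architecture above:
# Conv1:  5*5*1*16  + 16  = 416
# Conv2:  5*5*16*32 + 32  = 12,832
# Dense1: (7*7*32)*128 + 128 = 200,832
# Dense2: 128*10 + 10        = 1,290
print(model.count_params())  # 215370 trainable parameters in total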


2.5 Parameter extraction and saving

Save the trainable parameters of the model to the file weights.txt.

model.trainable_variables is a list containing the model's trainable variables.

#-------------------------- 7. Parameter extraction and saving --------------------------#
np.set_printoptions(threshold=np.inf)  # print arrays in full (threshold controls when output is truncated)

# print(model.trainable_variables)
file = open('./weights.txt', 'w')  # path where the weights are saved
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

2.6 acc/loss visualization

#-------------------------- 8. acc/loss visualization --------------------------#
# training set acc/loss
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']
# validation (test) set acc/loss
val_acc = history.history['val_sparse_categorical_accuracy']
val_loss = history.history['val_loss']

# accuracy curve
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Acc')
plt.plot(val_acc, label='Validation Acc')
plt.title('Training and Validation ACC')
plt.legend()

# loss curve
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

The matplotlib library is used to plot the acc/loss curves for the training and test sets, as shown in the figure below.
On the test set: acc > 0.98 and loss < 0.05.

img src="C:\Users\郑丽娟\AppData\Roaming\Typora\typora-user-images\image-20230427165358251.png" alt="image-20230427165358251" style="zoom:67%;" />

3 Summary

Through the handwritten digit recognition task, this article covered how to use the TensorFlow framework and the tf.keras API: how to build a convolutional neural network with tf.keras, how to save model parameters and resume training from checkpoints, and how to use the matplotlib library to visualize the training and test results.


Origin blog.csdn.net/weixin_44820505/article/details/130474463