Residual network implementation based on keras - taking fashion mnist data set classification as an example


Preface

Recently I have been learning about residual networks, and I tried using Keras to build my own residual network for some deep learning tasks. This post summarizes what I learned over the past few days.

Residual network

Generally speaking, the deeper a neural network is, the better it can extract features from and recognize data, but deeper networks also suffer from vanishing or exploding gradients and become harder to train. Kaiming He et al. therefore proposed the residual network structure in the paper "Deep Residual Learning for Image Recognition", which greatly eases the training of much deeper networks. Residual networks also rely on small convolution kernels (3×3, plus 1×1 convolutions in the deeper "bottleneck" blocks), which greatly reduces the computation needed for training.
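To make the savings from small kernels concrete, here is a rough parameter count for the bottleneck design described in the original paper: a 256-channel input reduced by a 1×1 convolution to 64 channels, a 3×3 convolution at 64 channels, then a 1×1 convolution back to 256 channels, compared against two plain 3×3 convolutions at 256 channels. This is a sketch that ignores biases; the helper name `conv_params` is illustrative only.

```python
def conv_params(kernel_size, in_channels, out_channels):
    """Weights of a bias-free k x k convolution: k * k * C_in * C_out."""
    return kernel_size * kernel_size * in_channels * out_channels

# Bottleneck: 1x1 (256 -> 64), 3x3 (64 -> 64), 1x1 (64 -> 256)
bottleneck = (conv_params(1, 256, 64)
              + conv_params(3, 64, 64)
              + conv_params(1, 64, 256))

# Two plain 3x3 convolutions at full width (256 -> 256 -> 256)
plain = 2 * conv_params(3, 256, 256)

print(bottleneck, plain, round(plain / bottleneck, 1))  # 69632 1179648 16.9
```

At a fixed spatial size every weight is used once per output pixel, so the number of multiply-adds shrinks by the same factor.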

Residual block

The basic unit of a residual network is the residual block, shown below.
[Figure: the basic residual block]
A residual block consists of two parts: the main (direct) path and the shortcut path. In the main path the data passes through several weight layers; in my code these are convolutional layers without bias. Because these convolutions can change the number of channels, the output of the main path may not have the same shape as the input, so the two cannot be added directly. Therefore the shortcut path also passes the data through a convolutional layer with a 1×1 kernel. With stride 1, a 1×1 convolution leaves the height and width unchanged and only changes the depth dimension of the data, that is, the number of channels. I also used two 1×1 convolutional layers in the main path to control its depth, so that both paths end with the same number of channels and can be added.
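A 1×1 convolution is simply the same C_in × C_out matrix applied at every pixel, which is why it can change the channel count without touching height and width. The pure-Python sketch below (the names `conv1x1`, `shape` and `add` are illustrative, not Keras APIs) projects a 2-channel feature map to 3 channels so it can be added to a 3-channel main-path output.

```python
def shape(t):
    """(height, width, channels) of a nested-list feature map."""
    return (len(t), len(t[0]), len(t[0][0]))

def conv1x1(feature_map, weights):
    """Bias-free 1x1 convolution: multiply every pixel's channel vector by a C_in x C_out matrix."""
    c_out = len(weights[0])
    return [
        [[sum(pixel[i] * weights[i][j] for i in range(len(pixel))) for j in range(c_out)]
         for pixel in row]
        for row in feature_map
    ]

def add(a, b):
    """Element-wise addition, as done by the Add layer; shapes must match exactly."""
    if shape(a) != shape(b):
        raise ValueError(f"cannot add shapes {shape(a)} and {shape(b)}")
    return [[[x + y for x, y in zip(pa, pb)] for pa, pb in zip(ra, rb)]
            for ra, rb in zip(a, b)]

# A 2 x 2 feature map with 2 channels ...
x = [[[1.0, 2.0], [3.0, 4.0]],
     [[5.0, 6.0], [7.0, 8.0]]]
# ... projected to 3 channels by a 2 x 3 weight matrix
w = [[1.0, 0.0, 1.0],
     [0.0, 1.0, 1.0]]
shortcut = conv1x1(x, w)                      # shape (2, 2, 3)

# Stand-in for a main-path output with 3 channels
main = [[[1.0] * 3 for _ in range(2)] for _ in range(2)]
print(add(main, shortcut)[0][0])              # [2.0, 3.0, 4.0]
```

Calling `add(main, x)` instead raises a `ValueError`, which mirrors why the unprojected input cannot be added to the main path.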

Residual network

The above is a brief description of the residual block. In practice there are many variants of the residual block, and different classification tasks may call for trying different ones. The figure below shows the residual block structure used in my code:
[Figure: the residual block used in this article's code]
A residual network is built by stacking residual blocks. The figure below, from the original paper, compares the 34-layer residual network (rightmost) with other networks.
[Figure: VGG-19, a 34-layer plain network, and a 34-layer residual network, from the original paper]
Residual networks differ greatly from the other structures, including in their computational cost: the VGG-19 model on the left requires about 19.6 billion FLOPs (multiply-adds), while the two 34-layer models on the right require only about 3.6 billion. Residual networks also recognize images better. In my tests, a simple convolutional neural network reached only about 91% accuracy on Fashion-MNIST and trained slowly, while the residual network I built from just four simple residual blocks easily exceeded 92% within about ten epochs. Although accuracy is hard to push much further after that, this illustrates the advantage of residual networks.

Fashion-MNIST dataset

Fashion-MNIST is an image dataset intended as a drop-in replacement for the MNIST handwritten digit dataset. It was created by the research arm of Zalando, a German fashion technology company, and contains 70,000 front-view images of products from 10 categories, as shown below:
[Figure: sample Fashion-MNIST images]

Fashion-MNIST has exactly the same image size, format and train/test split as the original MNIST: 60,000 training images and 10,000 test images, each a 28×28 grayscale image. It is a more useful benchmark than handwritten digits: on handwritten digits most models, including some conventional machine learning models, achieve very good results (many exceed 99%), which makes it hard to tell them apart.
The dataset is bundled with Keras and is loaded the same way as MNIST:

import tensorflow as tf
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

Code

The code is shown below. I wrote it in a Jupyter notebook; if you prefer, you can copy the cells into a .py file and run it yourself.
1. First import the required libraries:

import datetime
import tensorflow as tf
import numpy as np
import keras
from keras.layers import Input, Conv2D, AveragePooling2D, BatchNormalization, Activation, Add, Flatten, Dense, Dropout
from keras.models import Model
from keras.callbacks import ModelCheckpoint, TensorBoard

2. Define a helper that builds the weight layer, here a bias-free convolution:

def conv(channels, strides=1, kernel_size=(3, 3), padding='same'):
    # Convolution weight layer: no bias, random-normal initialization
    return Conv2D(filters=channels, kernel_size=kernel_size, strides=strides, padding=padding,
                use_bias=False, kernel_initializer=tf.random_normal_initializer())

3. Define the residual block (its structure is shown in the figure earlier):

def res_block(inputs, base_channels):
    '''Define a residual block.'''

    # Shortcut path: BN -> ReLU -> 1x1 conv to match the main path's channel count
    residual = inputs
    residual = BatchNormalization()(residual)
    residual = Activation('relu')(residual)
    residual = conv(channels=base_channels, kernel_size=(1, 1))(residual)

    # Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv (2x channels) -> BN -> ReLU -> 1x1 conv
    x = conv(channels=base_channels, kernel_size=(1, 1))(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = conv(channels=base_channels*2, strides=1, kernel_size=(3, 3))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = conv(channels=base_channels, kernel_size=(1, 1))(x)

    outputs = Add()([x, residual])

    return Activation('relu')(outputs)

4. Define the residual network. It has four residual blocks, followed at the output by an average-pooling layer, a flatten layer and fully connected layers:

def ResNet(input_shape, base_channels, classes):
    '''Define the residual network.'''
    inputs = Input(shape=input_shape)
    x = conv(channels=base_channels, strides=2, kernel_size=(3, 3))(inputs)

    x = res_block(x, base_channels=base_channels)
    x = res_block(x, base_channels=base_channels*2)
    x = res_block(x, base_channels=base_channels*2)
    x = res_block(x, base_channels=base_channels*4)

    x = AveragePooling2D()(x)
    x = Flatten()(x)
    x = Dense(512, activation='relu')(x)
    outputs = Dense(classes,activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    return model
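To see where the shapes printed by `model.summary()` in the results come from, the pure-Python sketch below traces a feature map through `ResNet` with `input_shape=(28, 28, 1)` and `base_channels=16`. It assumes Keras's output-size rules: `padding='same'` gives ceil(size / stride), and `AveragePooling2D()` with its defaults (2×2 pool, stride equal to the pool size, `'valid'` padding) gives floor((size - 2) / 2) + 1. The function name `trace_shapes` is illustrative.

```python
import math

def trace_shapes(height, width, base_channels):
    """Trace (H, W, C) through the stem, the four residual blocks, pooling and flatten."""
    # Stem: 3x3 conv, stride 2, 'same' padding -> ceil(size / 2)
    h, w, c = math.ceil(height / 2), math.ceil(width / 2), base_channels
    shapes = [("stem conv", (h, w, c))]
    # Every residual block uses stride 1, so only the channel count changes
    for i, mult in enumerate([1, 2, 2, 4], start=1):
        c = base_channels * mult
        shapes.append((f"res_block {i}", (h, w, c)))
    # AveragePooling2D(): 2x2 pool, stride 2, 'valid' padding
    h, w = (h - 2) // 2 + 1, (w - 2) // 2 + 1
    shapes.append(("average pooling", (h, w, c)))
    shapes.append(("flatten", (h * w * c,)))
    return shapes

for name, s in trace_shapes(28, 28, 16):
    print(name, s)
```

The last two lines come out as (7, 7, 64) and (3136,), matching `average_pooling2d_2` and `flatten_2` in the summary; those 3,136 features feed the 512-unit dense layer.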

5. Prepare the data set and perform preprocessing:

# Load the dataset
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Normalize pixel values to [0, 1]
x_train, x_test = x_train.astype(np.float32)/255., x_test.astype(np.float32)/255.
# Add a channel axis: [None, 28, 28] => [None, 28, 28, 1]
x_train, x_test = np.expand_dims(x_train, axis=3), np.expand_dims(x_test, axis=3)
# Convert labels to one-hot encoding
y_train_one = tf.one_hot(y_train,depth=10).numpy()
y_test_one = tf.one_hot(y_test,depth=10).numpy()

print(x_train.shape, y_train_one.shape)
print(x_test.shape, y_test_one.shape)

6. Define some hyperparameters:

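# Number of classes
num_classes = 10
# Batch size
batch_size = 32
# Number of epochs
epochs = 30
# Learning rate
learning_rate = 0.001
# Input shape
input_shape = (28, 28, 1)
# Project directory
project_path = "E:\\resnet\\"
# Log directory: must be a subdirectory of the directory passed to TensorBoard when it is started; a date-time subdirectory name is recommended
log_dir = project_path + "logs\\" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model_path = project_path + "model_best.h5"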

7. Build the model, compile it with the optimizer and loss function, and print its summary:

model = ResNet(input_shape=input_shape, base_channels=16, classes=num_classes)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
                loss='categorical_crossentropy',
                metrics=['accuracy'])
model.summary()
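The model is compiled with `categorical_crossentropy`, which is why the labels were one-hot encoded in step 5. For a single sample this loss reduces to the negative log of the probability the softmax assigns to the true class. The pure-Python sketch below uses made-up probabilities and illustrative helper names to show that the one-hot form and the "index of the true class" form agree; Keras's `sparse_categorical_crossentropy` computes the latter from integer labels, so it could be used with `y_train` directly instead of one-hot labels.

```python
import math

def one_hot(label, depth):
    """Same idea as tf.one_hot for a single integer label."""
    return [1.0 if i == label else 0.0 for i in range(depth)]

def categorical_crossentropy(y_true, y_pred):
    """-sum(y_true * log(y_pred)) for one sample (zero-weighted terms skipped)."""
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)

# Hypothetical softmax output for a 4-class problem
probs = [0.1, 0.7, 0.15, 0.05]
label = 1

loss_one_hot = categorical_crossentropy(one_hot(label, 4), probs)
loss_sparse = -math.log(probs[label])   # the value the sparse variant computes
print(round(loss_one_hot, 4), round(loss_sparse, 4))  # 0.3567 0.3567
```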

8. Train the model, write TensorBoard logs, and automatically save the best model:

# TensorBoard callback
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
# Checkpoint: save the model only when validation accuracy improves
checkpoint = ModelCheckpoint(filepath=model_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
# Train
model.fit(x_train, y_train_one,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(x_test, y_test_one),
            callbacks=[tensorboard_callback, checkpoint],
            verbose=1)
# Evaluate
scores = model.evaluate(x_test, y_test_one,batch_size=batch_size,verbose=1)
print("Final loss and accuracy:", scores)
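With `save_best_only=True`, `monitor='val_accuracy'` and `mode='max'`, `ModelCheckpoint` overwrites `model_best.h5` only when validation accuracy beats the best value seen so far. The plain-Python sketch below imitates that decision rule on hypothetical per-epoch accuracies (not results from this model); `best_checkpoints` is an illustrative name. Because `validation_data` here is the test set, the checkpoint is chosen on the test set, so the saved model's test accuracy may be slightly optimistic.

```python
def best_checkpoints(val_accuracies):
    """Return the 1-based epochs at which a max-mode, save-best-only checkpoint would save."""
    best = float("-inf")
    saved = []
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:          # only a strict improvement triggers a save
            best = acc
            saved.append(epoch)
    return saved

# Hypothetical validation accuracies for five epochs
history = [0.88, 0.90, 0.895, 0.91, 0.905]
print(best_checkpoints(history))  # [1, 2, 4]
```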

9. Plot some test images with their true and predicted labels:

import matplotlib.pyplot as plt

# Map class indices to class names
name_dict = {0: 't-shirt',1: 'trouser',2: 'pullover',3: 'dress',4: 'coat',
             5: 'sandal',6: 'shirt',7: 'sneaker',8: 'bag',9: 'ankle boot'}
# Plot the results using the best checkpoint saved during training
model_best = keras.models.load_model(model_path)
plot_image = x_test[10:20]
print(plot_image.shape)
predict_label = np.argmax(model_best.predict(plot_image), axis=1)
true_label = np.argmax(y_test_one[10:20], axis=1)
plot_image = np.reshape(plot_image, (10, 28, 28))

plt.figure(figsize=(25,10))
plt.suptitle('true/predict')

for i in range(1, 11):
    plt.subplot(2, 5, i)
    plt.imshow(plot_image[i-1])
    plt.axis('off')
    plt.title(name_dict[true_label[i-1]]+'/'+name_dict[predict_label[i-1]])

plt.show()

Results

Model structure

The model structure printed using model.summary() is as follows:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 14, 14, 16)   144         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 14, 14, 16)   256         conv2d_30[0][0]                  
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 14, 14, 16)   64          conv2d_32[0][0]                  
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 14, 14, 16)   0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 14, 14, 32)   4608        activation_32[0][0]              
__________________________________________________________________________________________________
batch_normalization_26 (BatchNo (None, 14, 14, 32)   128         conv2d_33[0][0]                  
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, 14, 14, 16)   64          conv2d_30[0][0]                  
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 14, 14, 32)   0           batch_normalization_26[0][0]     
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 14, 14, 16)   0           batch_normalization_24[0][0]     
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 14, 14, 16)   512         activation_33[0][0]              
__________________________________________________________________________________________________
conv2d_31 (Conv2D)              (None, 14, 14, 16)   256         activation_31[0][0]              
__________________________________________________________________________________________________
add_7 (Add)                     (None, 14, 14, 16)   0           conv2d_34[0][0]                  
                                                                 conv2d_31[0][0]                  
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 14, 14, 16)   0           add_7[0][0]                      
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 14, 14, 32)   512         activation_34[0][0]              
__________________________________________________________________________________________________
batch_normalization_28 (BatchNo (None, 14, 14, 32)   128         conv2d_36[0][0]                  
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 14, 14, 32)   0           batch_normalization_28[0][0]     
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 14, 14, 64)   18432       activation_36[0][0]              
__________________________________________________________________________________________________
batch_normalization_29 (BatchNo (None, 14, 14, 64)   256         conv2d_37[0][0]                  
__________________________________________________________________________________________________
batch_normalization_27 (BatchNo (None, 14, 14, 16)   64          activation_34[0][0]              
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 14, 14, 64)   0           batch_normalization_29[0][0]     
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 14, 14, 16)   0           batch_normalization_27[0][0]     
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 14, 14, 32)   2048        activation_37[0][0]              
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 14, 14, 32)   512         activation_35[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, 14, 14, 32)   0           conv2d_38[0][0]                  
                                                                 conv2d_35[0][0]                  
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 14, 14, 32)   0           add_8[0][0]                      
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 14, 14, 32)   1024        activation_38[0][0]              
__________________________________________________________________________________________________
batch_normalization_31 (BatchNo (None, 14, 14, 32)   128         conv2d_40[0][0]                  
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 14, 14, 32)   0           batch_normalization_31[0][0]     
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 14, 14, 64)   18432       activation_40[0][0]              
__________________________________________________________________________________________________
batch_normalization_32 (BatchNo (None, 14, 14, 64)   256         conv2d_41[0][0]                  
__________________________________________________________________________________________________
batch_normalization_30 (BatchNo (None, 14, 14, 32)   128         activation_38[0][0]              
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 14, 14, 64)   0           batch_normalization_32[0][0]     
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 14, 14, 32)   0           batch_normalization_30[0][0]     
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 14, 14, 32)   2048        activation_41[0][0]              
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 14, 14, 32)   1024        activation_39[0][0]              
__________________________________________________________________________________________________
add_9 (Add)                     (None, 14, 14, 32)   0           conv2d_42[0][0]                  
                                                                 conv2d_39[0][0]                  
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 14, 14, 32)   0           add_9[0][0]                      
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 14, 14, 64)   2048        activation_42[0][0]              
__________________________________________________________________________________________________
batch_normalization_34 (BatchNo (None, 14, 14, 64)   256         conv2d_44[0][0]                  
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 14, 14, 64)   0           batch_normalization_34[0][0]     
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 14, 14, 128)  73728       activation_44[0][0]              
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 14, 14, 128)  512         conv2d_45[0][0]                  
__________________________________________________________________________________________________
batch_normalization_33 (BatchNo (None, 14, 14, 32)   128         activation_42[0][0]              
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 14, 14, 128)  0           batch_normalization_35[0][0]     
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 14, 14, 32)   0           batch_normalization_33[0][0]     
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 14, 14, 64)   8192        activation_45[0][0]              
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 14, 14, 64)   2048        activation_43[0][0]              
__________________________________________________________________________________________________
add_10 (Add)                    (None, 14, 14, 64)   0           conv2d_46[0][0]                  
                                                                 conv2d_43[0][0]                  
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 14, 14, 64)   0           add_10[0][0]                     
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 7, 7, 64)     0           activation_46[0][0]              
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 3136)         0           average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 512)          1606144     flatten_2[0][0]                  
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 10)           5130        dense_3[0][0]                    
==================================================================================================
Total params: 1,749,210
Trainable params: 1,748,154
Non-trainable params: 1,056
__________________________________________________________________________________________________

The model has about 1.75 million parameters, of which 1,606,144 (over 90%) belong to the 512-unit fully connected layer dense_3. Without that layer the model would need far fewer trainable parameters (the convolutional and batch-normalization layers together hold only about 138,000), but in my tests the accuracy was lower.
The model diagram is as follows:
[Figure: model diagram]
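The totals in the summary can be reproduced by hand. A bias-free k×k convolution has k·k·C_in·C_out weights, a BatchNormalization layer stores 4·C values (gamma and beta are trained; the moving mean and variance are not), and a Dense layer has inputs·units + units. The pure-Python sketch below applies these rules to the layers built by `conv`, `res_block` and `ResNet`; the helper names are illustrative and deliberately differ from the Keras functions above.

```python
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out            # use_bias=False

def bn_params(c):
    return 4 * c                           # gamma, beta, moving mean, moving variance

def dense_params(n_in, units):
    return n_in * units + units            # weights + bias

def block_params(c_in, base):
    shortcut = bn_params(c_in) + conv_params(1, c_in, base)
    main = (conv_params(1, c_in, base) + bn_params(base)
            + conv_params(3, base, 2 * base) + bn_params(2 * base)
            + conv_params(1, 2 * base, base))
    return shortcut + main

def bn_channels(c_in, base):
    """Channel counts of the three BN layers in one block."""
    return [c_in, base, 2 * base]

base = 16
# (input channels, base_channels) for the four res_block calls
blocks = [(base, base), (base, 2 * base), (2 * base, 2 * base), (2 * base, 4 * base)]

total = conv_params(3, 1, base)                               # stem conv
total += sum(block_params(c_in, b) for c_in, b in blocks)     # residual blocks
total += dense_params(7 * 7 * 4 * base, 512) + dense_params(512, 10)
non_trainable = sum(2 * c for c_in, b in blocks for c in bn_channels(c_in, b))

print(total, non_trainable)  # 1749210 1056
```

Both numbers match the "Total params" and "Non-trainable params" lines of the summary, and `dense_params(3136, 512)` alone accounts for 1,606,144 of them.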

Prediction results

The saved model reached an accuracy of 92.38% on the test set, after only 6 epochs.

[Figure: accuracy during training]
After that, only the accuracy on the training set continued to improve.
Below are the predicted classes for ten images from the test set (each title reads true/predicted). The results are very good:
[Figure: true and predicted labels for ten test images]
In earlier tests, the coat in the second row, third column was often recognized as a shirt; this model is the only one that predicted it correctly. Below are some misclassified examples; many of these images are hard to identify even with the naked eye.
[Figure: misclassified test images]



Source: blog.csdn.net/qq_44725872/article/details/125373317