深度学习基础 - MNIST实验(Tensorflow-CNN)
本文的完整代码托管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,欢迎交流。
1.任务背景
这里,我们拟通过搭建卷积神经网络(CNN)来完成MNIST手写数字识别任务,关于MNIST任务的相关内容可参考前文深度学习基础 - MNIST实验(tensorflow+Softmax)或深度学习基础 - MNIST实验(tensorflow+MLP)。
2.实验过程
实验参考代码:python + tensorflow: cnn_demo.py & cnn_demo_self_test.py
实验分三步进行:
- 参考LeNet-5,搭建适用于该任务的CNN模型,开发实现基于tensorflow;
- 加载MNIST数据集,配置超参数,进行训练与测试,分析效果;
- 加载自制手写图片,采用训练好的CNN进行识别,分析效果;
2.1.CNN建模
LeNet-5是Y.LeCun等人早期所设计的一种CNN,是经典的神经网络架构之一,如下图所示:(参考原文献)
本实验采用python-tensorflow实现LeNet-5,其建模代码样例如下:
'''construction of leNet-5 model'''
def lenet_5_forward_propagation(X):
"""
@note: construction of leNet-5 forward computation graph:
CONV1 -> MAXPOOL1 -> CONV2 -> MAXPOOL2 -> FC3 -> FC4 -> SOFTMAX
@param X: input dataset placeholder, of shape (number of examples (m), input size)
@return: A_l, the output of the softmax layer, of shape (number of examples, output size)
"""
# reshape imput as [number of examples (m), weight, height, channel]
X_ = tf.reshape(X, [-1, 28, 28, 1]) # num_channel = 1 (gray image)
### CONV1 (f = 5*5*1, n_f = 6, s = 1, p = 'same')
W_conv1 = weight_variable(shape = [5, 5, 1, 6])
b_conv1 = bias_variable(shape = [6])
# shape of A_conv1 ~ [m,28,28,6]
A_conv1 = tf.nn.relu(tf.nn.conv2d(X_, W_conv1, strides = [1, 1, 1, 1], padding = 'SAME') + b_conv1)
### MAXPOOL1 (f = 2*2*1, s = 2, p = 'same')
# shape of A_pool1 ~ [m,14,14,6]
A_pool1 = tf.nn.max_pool(A_conv1, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')
### CONV2 (f = 5*5*1, n_f = 16, s = 1, p = 'same')
W_conv2 = weight_variable(shape = [5, 5, 6, 16])
b_conv2 = bias_variable(shape = [16])
# shape of A_conv2 ~ [m,10,10,16]
A_conv2 = tf.nn.relu(tf.nn.conv2d(A_pool1, W_conv2, strides = [1, 1, 1, 1], padding = 'VALID') + b_conv2)
### MAXPOOL2 (f = 2*2*1, s = 2, p = 'same')
# shape of A_pool2~ [m,5,5,16]
A_pool2 = tf.nn.max_pool(A_conv2, ksize = [1, 2, 2, 1], strides=[1, 2, 2, 1], padding = 'SAME')
### FC3 (n = 120)
# flatten the volumn to vector
A_pool2_flat = tf.reshape(A_pool2, [-1, 5*5*16])
W_fc3 = weight_variable([5*5*16, 120])
b_fc3 = bias_variable([120])
# shape of A_fc3 ~ [m,120]
A_fc3 = tf.nn.relu(tf.matmul(A_pool2_flat, W_fc3) + b_fc3)
### FC4 (n = 84)
W_fc4 = weight_variable([120, 84])
b_fc4 = bias_variable([84])
# shape of A_fc4 ~ [m, 84]
A_fc4 = tf.nn.relu(tf.matmul(A_fc3, W_fc4) + b_fc4)
# Softmax (n = 10)
W_l = weight_variable([84, 10])
b_l = bias_variable([10])
# shape of A_l ~ [m,10]
A_l=tf.nn.softmax(tf.matmul(A_fc4, W_l) + b_l)
return A_l
2.2.训练与测试
设置优化策略及相关超参数(如learning_rate
、num_epochs
、mini-batch size
等),进行训练,经过一段时间的训练,得出的accuracy
结果如下:
Train Accuracy: 0.9920
Valid Accuracy: 0.9896
Test Accuracy: 0.9881
同时该训练期间,指标accuracy
和cost
的变化过程如下图示:
可以看出,此处CNN(LeNet-5)已经取得了不错的结果(≈99%的测试准确率)。而通过观察训练曲线变化趋势,猜测随着迭代的继续,模型效果还可继续提升。
2.3.实测
接下来验证该CNN模型在生活场景下的泛化效果,笔者此处在实验室即兴写了若干待识别数字,示意如下:
采用之前所训练的CNN,得出预测结果示意如下:
结果中出现了一些识别错误,初步猜测是由数据分布的差异所引起。可以考虑在图像训练和测试时,先采用更多的预处理手段(如灰度归一化、对比度增强、阈值分割…),从而使分布接近。降低模型迁移难度。
3.实验小结
本文采用CNN模型进行mnist手写数字识别任务,取得了很好的效果(99%的测试准确率)。同时采用训练好的模型识别了实际场景中的数字,体现了一定的识别效果。
4.参考资料
官方参考:
CNN模型:
开发辅助: