[Deep Learning] Experiment 14: Using a CNN for MNIST Handwriting Recognition (TensorFlow)


A CNN (Convolutional Neural Network) is a widely used neural network model, typically applied in fields such as image recognition and speech recognition. Compared with traditional fully connected networks, a CNN has clear advantages when processing image-like data. Its core idea is to extract features from images through operations such as convolution and pooling, enabling tasks such as image classification.

The basic structure of a CNN consists of convolutional layers, pooling layers, and fully connected layers. The convolutional layer is the core of a CNN: it extracts features from the input image using convolution kernels. Each kernel is a small weight matrix that acts as a specific filter. The kernel slides over the input image; at each position, the pixel values are multiplied element-wise by the corresponding kernel weights and summed, producing one value of a new feature map. By stacking multiple convolution operations, progressively higher-level features can be extracted from the image.
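To make this concrete, here is a minimal NumPy sketch of the sliding-window computation (the image and kernel values are made up for illustration):

import numpy as np

# A toy 4x4 "image" and a 3x3 kernel (values chosen arbitrarily for illustration)
image = np.array([[1, 2, 0, 1],
                  [0, 1, 3, 1],
                  [2, 1, 0, 2],
                  [1, 0, 1, 1]], dtype=np.float32)
kernel = np.array([[1, 0, -1],
                   [1, 0, -1],
                   [1, 0, -1]], dtype=np.float32)

# Slide the kernel over every valid position: multiply element-wise, then sum
kh, kw = kernel.shape
out_h = image.shape[0] - kh + 1
out_w = image.shape[1] - kw + 1
feature_map = np.zeros((out_h, out_w), dtype=np.float32)
for i in range(out_h):
    for j in range(out_w):
        feature_map[i, j] = np.sum(image[i:i+kh, j:j+kw] * kernel)
print(feature_map)  # the resulting 2x2 feature map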

A pooling layer is usually added after a convolutional layer to reduce the spatial size of the feature map. Common pooling methods are max pooling and average pooling: max pooling takes the maximum value within each input region as the pooled value, while average pooling takes the region's mean. Pooling reduces the dimensionality of the feature map and the computational cost, and also improves the robustness of the model, so that small local changes in the input do not strongly affect the output.
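For example, 2x2 max pooling with stride 2 keeps only the largest value in each non-overlapping 2x2 block, halving both spatial dimensions (a minimal NumPy sketch with made-up values):

import numpy as np

feature_map = np.array([[1, 3, 2, 0],
                        [4, 2, 1, 5],
                        [0, 1, 3, 2],
                        [2, 6, 0, 1]], dtype=np.float32)

# 2x2 max pooling with stride 2: take the max of each 2x2 block
pooled = feature_map.reshape(2, 2, 2, 2).max(axis=(1, 3))
print(pooled)  # [[4. 5.]
               #  [6. 3.]]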

Finally, the fully connected layers take the feature maps produced by convolution and pooling, flattened into a one-dimensional vector, and map them to the output; by combining several fully connected layers, tasks such as image classification and object detection can be carried out. When training a CNN, the backpropagation algorithm is typically used to optimize the network parameters: it computes the gradient of the loss function with respect to each parameter, which is then used to update that parameter.
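In TensorFlow 1.x this gradient computation is built in; a minimal sketch with a made-up one-variable loss shows how the gradient of a loss with respect to a parameter can be obtained:

import tensorflow as tf

# A toy parameter and loss, just to illustrate gradient computation
w = tf.Variable(2.0)
loss = tf.square(w - 5.0)          # loss = (w - 5)^2
grad = tf.gradients(loss, [w])[0]  # d(loss)/dw = 2 * (w - 5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grad))  # -6.0, since 2 * (2 - 5) = -6

In practice an optimizer such as AdamOptimizer performs this gradient computation and the parameter update in a single minimize() call, as done later in this experiment.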

In short, a CNN is a very effective neural network model that has been widely applied in image processing and related fields. For beginners, however, building and training a CNN requires some mathematical and programming background, so a certain foundation is needed to master it.

1. Import the TensorFlow library

# TensorFlow provides a helper class for loading the MNIST data
# Import the required libraries
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import time
import warnings
warnings.filterwarnings('ignore')

2. Dataset

# Load the dataset (downloads it on first run)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
mnist
   WARNING:tensorflow:From <ipython-input-2-574fb576d2f2>:2: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please write your own downloading logic.
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please use tf.data to implement this functionality.
   Extracting MNIST_data/train-images-idx3-ubyte.gz
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please use tf.data to implement this functionality.
   Extracting MNIST_data/train-labels-idx1-ubyte.gz
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please use tf.one_hot on tensors.
   Extracting MNIST_data/t10k-images-idx3-ubyte.gz
   Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
   WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
   Instructions for updating:
   Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
   Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e2b0>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e358>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x7f05f902e320>)
# Set the batch size
batch_size = 50
# Compute how many batches there are in total
n_batch = mnist.train.num_examples // batch_size
n_batch
1100
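As a quick sanity check, the loaded splits can be inspected directly; with one_hot=True each label is a length-10 vector, and the 55,000 training images explain the 1100 batches above (55000 // 50):

# Inspect the loaded splits
print(mnist.train.images.shape)       # (55000, 784)
print(mnist.train.labels.shape)       # (55000, 10)
print(mnist.validation.images.shape)  # (5000, 784)
print(mnist.test.images.shape)        # (10000, 784)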
# Weight initialization: truncated normal with stddev 0.1
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
# Bias initialization: small positive constant 0.1
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
"""
strides=[b,h,w,c]
b表示在样本上的步长默认为1,也就是每一个样本都会进行运算。
h表示在高度上的默认移动步长为1,这个可以自己设定,根据网络的结构合理调节。
w表示在宽度上的默认移动步长为1,这个同上可以自己设定。
c表示在通道上的默认移动步长为1,这个表示每一个通道都会进行运算
"""
# 卷积层
def conv2d(input, filter):
    return tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
"""
ksize=[b,h,w,c],通常为[1,2,2,1]
b表示在样本上的步长默认为1,也就是每一个样本都会进行运算。
h表示在高度上的默认移动步长为1,这个可以自己设定,根据网络的结构合理调节。
w表示在宽度上的默认移动步长为1,这个同上可以自己设定。
c表示在通道上的默认移动步长为1,这个表示每一个通道都会进行运算
"""
# 池化层
def max_pool_2x2(value):
    return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# Input layer
# Define two placeholders for the images and labels
x = tf.placeholder(tf.float32, [None, 784]) # 28*28 pixels, flattened
y = tf.placeholder(tf.float32, [None, 10])
# Reshape x into a 4-D tensor [batch, in_height, in_width, in_channels]
x_image = tf.reshape(x, [-1, 28, 28, 1])

3. Convolution, activation, and pooling operations

# Initialize the weights and biases of the first convolutional layer
# MNIST images are grayscale with one value per pixel, so the input has 1 channel
# 5*5 sampling window, 32 kernels extracting features from 1 input plane
w_conv1 = weight_variable([5, 5, 1, 32])
# One bias value per kernel
b_conv1 = bias_variable([32])
# Convolve x_image with the weights, add the bias, then apply the ReLU activation
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
# Max pooling layer, output 14*14*32
h_pool1 = max_pool_2x2(h_conv1)
# Initialize the weights and biases of the second convolutional layer
# 5*5 sampling window, 64 kernels extracting features from 32 input planes
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
# Convolve the first pooling layer's output with the weights, add the bias, then apply ReLU
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
# Pooling layer, output 7*7*64
h_pool2 = max_pool_2x2(h_conv2)

Because SAME padding is used, the 28x28 image is still 28x28 after the first convolution; the first pooling halves it to 14x14. The second convolution again keeps it at 14x14, and the second pooling halves it to 7x7. After these operations we therefore have 64 feature planes of size 7x7.
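These sizes can be verified directly from the graph, since the static shapes are known at construction time:

# Verify each layer's output shape (the batch dimension is unknown, shown as ?)
print(h_conv1.shape)  # (?, 28, 28, 32) -- SAME padding keeps 28x28
print(h_pool1.shape)  # (?, 14, 14, 32)
print(h_conv2.shape)  # (?, 14, 14, 64)
print(h_pool2.shape)  # (?, 7, 7, 64)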

4. Fully connected layer

# Initialize the weights of the first fully connected layer
# The pooling output has 7*7*64 neurons; this layer has 128 neurons
w_fc1 = weight_variable([7 * 7 * 64, 128])
# 128 bias values, one per neuron
b_fc1 = bias_variable([128])
# Flatten the output of pooling layer 2 into a 1-D vector per sample
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
# Compute the output of the first fully connected layer
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
h_pool2_flat.shape, h_fc1.shape
(TensorShape([Dimension(None), Dimension(3136)]),
 TensorShape([Dimension(None), Dimension(128)]))
# keep_prob: float, the probability that each element is kept (i.e. that a neuron
# stays active); it is defined here as a placeholder and fed at run time
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
WARNING:tensorflow:From <ipython-input-14-a22383db216d>:3: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
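As the warning notes, newer TensorFlow versions replace keep_prob with rate, the probability of dropping (not keeping) a unit; the equivalent call would be:

# Equivalent call in the newer API: rate = 1 - keep_prob
h_fc1_drop = tf.nn.dropout(h_fc1, rate=1 - keep_prob)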
# Initialize the second (output) fully connected layer: 128 -> 10 classes
W_fc2 = weight_variable([128, 10])
b_fc2 = bias_variable([10])

5. Output layer

# Compute the output. The raw logits are kept separate, because the loss op
# below applies softmax internally; feeding already-softmaxed values in as
# logits would apply softmax twice
logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
prediction = tf.nn.softmax(logits)
# Cross-entropy cost function
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Optimize with AdamOptimizer
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Store the results in a boolean list (argmax returns the index of the largest value in a 1-D tensor)
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
# Compute the accuracy (tf.cast converts the booleans to floats)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
WARNING:tensorflow:From <ipython-input-17-ef3c12a7f7c4>:2: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.
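As the warning suggests, the same loss can be written with the _v2 op; wrapping the labels in tf.stop_gradient reproduces the old behavior of treating them as constants:

# Non-deprecated equivalent of the loss above
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(y), logits=logits))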

6. Training the model

# Create a session
with tf.Session() as sess:
    start_time = time.time()  # time.clock() is deprecated (removed in Python 3.8)
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    print('Start training ----------')
    # Train for 10 epochs
    for epoch in range(10):
        print('Epoch ' + str(epoch) + ' :')
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # One training step on this batch; keep 70% of the units during training
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})
            print('Batch ' + str(batch) + ' trained')
        # Compute the accuracy on the test data (no dropout at test time)
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        print('Iter ' + str(epoch) + ', Testing Accuracy = ' + str(acc))
    end_time = time.time()
    # Print the running time
    print('Running time: %s seconds' % (end_time - start_time))
Start training ----------
Epoch 0 :
Batch 0 trained
Batch 1 trained
Batch 2 trained
……

Attachment: Series of Articles

1. Boston house price forecast: https://want595.blog.csdn.net/article/details/132181950
2. Iris dataset analysis: https://want595.blog.csdn.net/article/details/132182057
3. Feature processing: https://want595.blog.csdn.net/article/details/132182165
4. Cross-validation: https://want595.blog.csdn.net/article/details/132182238
5. Constructing a neural network example: https://want595.blog.csdn.net/article/details/132182341
6. Completing linear regression using TensorFlow: https://want595.blog.csdn.net/article/details/132182417
7. Completing logistic regression using TensorFlow: https://want595.blog.csdn.net/article/details/132182496
8. TensorBoard case: https://want595.blog.csdn.net/article/details/132182584
9. Completing linear regression using Keras: https://want595.blog.csdn.net/article/details/132182723
10. Completing logistic regression using Keras: https://want595.blog.csdn.net/article/details/132182795
11. Cat and dog recognition using a Keras pre-trained model: https://want595.blog.csdn.net/article/details/132243928
12. Training models using PyTorch: https://want595.blog.csdn.net/article/details/132243989
13. Using Dropout to suppress overfitting: https://want595.blog.csdn.net/article/details/132244111
14. Using CNN to complete MNIST handwriting recognition (TensorFlow): https://want595.blog.csdn.net/article/details/132244499
15. Using CNN to complete MNIST handwriting recognition (Keras): https://want595.blog.csdn.net/article/details/132244552
16. Using CNN to complete MNIST handwriting recognition (PyTorch): https://want595.blog.csdn.net/article/details/132244641
17. Using GAN to generate handwritten digit samples: https://want595.blog.csdn.net/article/details/132244764
18. Natural language processing: https://want595.blog.csdn.net/article/details/132276591
