MNIST handwritten digit recognition (TensorFlow-GPU): principle and code

This article recognizes the handwritten digits in the MNIST data set with a relatively simple machine learning model based on Softmax Regression.

This article gives a general introduction to neural networks and teaches a simple image-recognition technique. The images used here come from MNIST, an open-source training data set.

The article proceeds in the following parts:

  1. Import the data set.
  2. Analyze the characteristics of mnist samples to define variables.
  3. Build the model.
  4. Train the model and output intermediate state parameters.
  5. Test the model.
  6. Save the model.
  7. Read the model.

1. Import a dataset of handwritten pictures

(1) mnist data set

    The MNIST data set contains a large number of images of handwritten digits, as shown in Figure 1 below.

    Each image comes with a label that tells us which digit it shows. For example, the four images below are labeled 5, 0, 4, and 1.

     

   (2) Download the MNIST data set using TensorFlow code

     Download the MNIST data set with the library provided by TensorFlow:

# -*- coding: utf-8 -*-
# !/usr/bin/env python
# @Time    : 2019/5/17 17:03
# @Author  : xhh
# @Desc    : MNIST data set download
# @File    : mnist_data_load.py
# @Software: PyCharm
# Note: tensorflow.examples.tutorials only exists in TensorFlow 1.x (removed in 2.x)
from tensorflow.examples.tutorials.mnist import input_data
import pylab

mnist = input_data.read_data_sets("MINST_daya/", one_hot=True)
print("Input data:", mnist.train.images)
print("Data shape:", mnist.train.images.shape)

# Display one image from the data set
im = mnist.train.images[1]
im = im.reshape(-1, 28)  # 784 values -> 28x28 image
pylab.imshow(im)         # add cmap='gray' to show the digit in grayscale
pylab.show()

Running the code above automatically downloads the data set into the MINST_daya folder under the current working directory (the folder name is the path passed to read_data_sets). The files are kept in their compressed .gz form and are read directly from there.

Note:

one_hot=True in the code means that the sample labels are converted to one-hot encoding.

One-hot encoding:

Suppose there are 10 classes. The one-hot code of 0 is 1000000000, that of 1 is 0100000000, that of 2 is 0010000000, and so on: exactly one bit is 1, and its position indicates the class.
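To make the encoding concrete, here is a small NumPy sketch. It does not use TensorFlow; the helper name `to_one_hot` is hypothetical and only illustrates the kind of label rows that `one_hot=True` produces.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Hypothetical helper: turn integer labels into one-hot rows."""
    labels = np.asarray(labels)
    # Row i of the identity matrix has a single 1 at position i
    return np.eye(num_classes, dtype=np.float32)[labels]

encoded = to_one_hot([5, 0, 4, 1])   # the four labels from Figure 1
print(encoded[0])                    # label 5 -> [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

# Going back: the position of the 1 is the class
decoded = np.argmax(encoded, axis=1)
```

Indexing the identity matrix keeps the conversion to one line; `np.argmax` reverses it, which is the same idea the accuracy calculation later relies on.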

Output:

Figure 2

The printout shows that the training-set images form a matrix with 55,000 rows and 784 columns: the training set contains 55,000 images, and each image is stored as one row of 784 (28*28) values. Each value in the brackets represents one pixel.
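Each row is simply the 28×28 image read out row by row. The sketch below uses a synthetic array as a stand-in for `mnist.train.images` (an assumption, so it runs without downloading anything) to show that `reshape(-1, 28)`, as used in the display code, recovers the 2-D image.

```python
import numpy as np

# Synthetic stand-in for a few rows of mnist.train.images (values in [0, 1))
rng = np.random.default_rng(0)
images = rng.random((3, 784), dtype=np.float32)

row = images[1]              # one flattened image: shape (784,)
img = row.reshape(-1, 28)    # -1 lets NumPy infer the row count: 784 / 28 = 28
print(img.shape)             # (28, 28)

# Flattening again restores the original row exactly
restored = img.reshape(784)
```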

Figure 2 shows the digit in color, but only because pylab.imshow draws single-channel data with a default color map. MNIST images are actually grayscale, single-channel images; a true color image would have 3 channels (RGB: red, green, blue). In a grayscale image each pixel is a single number, from 0 to 255 in the raw files, that gives its gray level. Passing cmap='gray' to pylab.imshow displays the digit in black and white.

(3) Composition of the MNIST data set

      In the MNIST training set, mnist.train.images is a tensor of shape [55000, 784]. The first dimension indexes the images and the second indexes the pixels within each image. Each element of the tensor is the intensity of one pixel in one image. In the raw files these values range from 0 to 255; with the default settings, read_data_sets converts them to floats between 0 and 1.
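A minimal sketch of that rescaling on a few synthetic pixel values: dividing by 255 maps the raw 8-bit gray levels into [0, 1]. This mirrors what the TF 1.x loader does when it returns float32 images, but it is an illustration, not the loader's own code.

```python
import numpy as np

# Synthetic raw pixels as stored in the MNIST files: unsigned 8-bit gray levels
raw = np.array([[0, 64, 128, 255]], dtype=np.uint8)

# Convert to float32 and scale into [0.0, 1.0]
scaled = raw.astype(np.float32) / 255.0
print(scaled)
```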

 MNIST contains 3 data sets:

the training set (mnist.train, 55,000 images), the test set (mnist.test, 10,000 images), and the validation set (mnist.validation, 5,000 images).

The following figure shows the downloaded MNIST archive files:

2. Analyze the characteristics of pictures and define variables

       Since the input images form a 55,000×784 matrix, first create a placeholder x of shape [None, 784] and a placeholder y of shape [None, 10], and then use the feed mechanism to feed in the images and labels.

The code is as follows:

from tensorflow.examples.tutorials.mnist import input_data
import pylab
import tensorflow as tf

mnist = input_data.read_data_sets("MINST_daya/", one_hot=True)

tf.reset_default_graph()

# Define placeholders
x = tf.placeholder(tf.float32, [None, 784])  # MNIST image dimension: 28*28 = 784
y = tf.placeholder(tf.float32, [None, 10])   # digits 0-9 ==> 10 classes

     In the placeholder definitions, None means that the first dimension of the tensor can have any length. For x, this means that any number of MNIST images can be fed in, each flattened into a 784-dimensional vector; likewise, y holds one 10-dimensional one-hot label per image.
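The `None` dimension works because the matrix product in the model does not depend on how many rows x has. In this NumPy sketch, zeros stand in for real images and weights (only the shapes matter). It shows the same `[784, 10]` weight matrix producing one 10-value row per image for several batch sizes.

```python
import numpy as np

W = np.zeros((784, 10), dtype=np.float32)  # same shape as the model's W
b = np.zeros(10, dtype=np.float32)         # same shape as the model's b

shapes = []
for batch_size in (1, 100, 10000):
    x = np.zeros((batch_size, 784), dtype=np.float32)  # any number of flattened images
    logits = x @ W + b                                # b is broadcast to every row
    shapes.append(logits.shape)

print(shapes)  # [(1, 10), (100, 10), (10000, 10)]
```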

3. Build the model

(1) Define learning parameters

    In TensorFlow, learning parameters are defined with Variable. The model needs weights and biases, which are collectively called learning parameters. A Variable is a modifiable tensor that lives in the TensorFlow graph. It can be used in computations on the input values and can also be updated by those computations (for example, by the optimizer during training).


# Define the learning parameters
# Model weights and bias
W = tf.Variable(tf.random_normal([784, 10]))  # W has shape [784, 10]
b = tf.Variable(tf.zeros([10]))               # b has shape [10]

Passing different initial values to tf.Variable creates different parameters. Typically W is initialized with random values and b with zeros.

(2) Define the output node

With the inputs and model parameters, you can then string them together to build a real model.


# Define the output node to build the model
pred = tf.nn.softmax(tf.matmul(x, W) + b)  # softmax classification

        First, tf.matmul(x, W) multiplies x by W, where x is a two-dimensional tensor holding a batch of inputs. Then b is added and the sum is passed to tf.nn.softmax. This completes the forward-propagation structure: given suitable model parameters, feeding in data produces the classification we want.
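For readers who want to see the arithmetic, this NumPy sketch computes the same forward pass, softmax(xW + b), on a tiny made-up example (4 input features and 3 classes instead of 784 and 10). Subtracting each row's maximum before exponentiating is a common numerical-stability trick and does not change the result.

```python
import numpy as np

def softmax(z):
    # Subtract the row-wise max for numerical stability; the output is unchanged
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

x = np.array([[1.0, 0.0, 2.0, 0.5]])   # one input with 4 features
W = np.array([[0.2, -0.1,  0.0],
              [0.0,  0.3,  0.1],
              [0.5,  0.0, -0.2],
              [0.1,  0.1,  0.1]])       # shape [4, 3]
b = np.zeros(3)

logits = x @ W + b                     # [[1.25, -0.05, -0.35]]
pred = softmax(logits)
print(pred, pred.argmax(axis=1))       # class probabilities and the predicted class
```

The probabilities in each row sum to 1, and the class with the largest logit (class 0 here) also gets the largest probability.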

(3) Define the back propagation structure


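# Define the back-propagation structure: the loss that training minimizes
# Cross-entropy between the labels y and the predictions pred, averaged over the batch
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), axis=1))

# Hyperparameters
learning_rate = 0.01
# Use the gradient descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)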

How the back-propagation structure works:

(1) Compute the cross-entropy between the prediction pred and the sample label y, then take the average over the batch.
(2) Treat this result as the error of one forward pass, and use gradient descent to find the changes to b and W that reduce this error.
(3) Update b and W accordingly. The whole process keeps making the loss (cost) smaller, and the smaller the loss, the closer the output is to the labels. When the cost is as small as we need, the current b and W are the trained parameters.
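These three steps can be traced by hand. The sketch below is a NumPy re-implementation, not TensorFlow's own code. It computes the same averaged cross-entropy on a small synthetic batch, uses the standard gradient of softmax plus cross-entropy (predictions minus labels, averaged over the batch), and takes one gradient-descent step. The data, the toy sizes, and the learning rate of 0.1 are arbitrary assumptions for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(y, pred):
    # Same formula as the TensorFlow cost: batch mean of -sum(y * log(pred))
    return np.mean(-np.sum(y * np.log(pred), axis=1))

rng = np.random.default_rng(42)
n, d, k = 20, 5, 3                          # samples, features, classes (toy sizes)
x = rng.standard_normal((n, d))
y = np.eye(k)[rng.integers(0, k, size=n)]   # random one-hot labels
W = rng.standard_normal((d, k))
b = np.zeros(k)
learning_rate = 0.1

# (1) Forward pass and averaged cross-entropy
pred = softmax(x @ W + b)
loss_before = cross_entropy(y, pred)

# (2) Gradients of the mean cross-entropy with respect to W and b
grad_logits = (pred - y) / n
grad_W = x.T @ grad_logits
grad_b = grad_logits.sum(axis=0)

# (3) Move W and b against the gradient
W -= learning_rate * grad_W
b -= learning_rate * grad_b

loss_after = cross_entropy(y, softmax(x @ W + b))
print(loss_before, loss_after)  # the loss is lower after the step
```

GradientDescentOptimizer repeats this kind of update for every batch, with the gradients computed automatically from the graph.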

(4) Train the model and output intermediate state parameters, and save and test the model

Define the training parameters:

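import os

training_epochs = 25  # pass over the whole training set 25 times
batch_size = 100      # draw 100 samples for each training step
display_step = 1      # print progress every display_step epochs
saver = tf.train.Saver()
model_path = "mnist/521model.ckpt"
# Make sure the checkpoint directory exists before saving
os.makedirs(os.path.dirname(model_path), exist_ok=True)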
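With batch_size = 100 and 55,000 training images, each epoch performs int(55000 / 100) = 550 optimizer steps, and avg_cost ends up as the mean of the 550 batch losses. The NumPy sketch below mimics that bookkeeping with a hypothetical helper, `iterate_minibatches`. Shuffling the indices once per epoch is an assumption made for the sketch, not a description of how `mnist.train.next_batch` is implemented.

```python
import numpy as np

def iterate_minibatches(num_examples, batch_size, rng):
    """Hypothetical helper: yield shuffled index batches covering one epoch."""
    order = rng.permutation(num_examples)
    total_batch = num_examples // batch_size     # same as int(n / batch_size)
    for i in range(total_batch):
        yield order[i * batch_size:(i + 1) * batch_size]

rng = np.random.default_rng(0)
num_examples, batch_size = 55000, 100
batches = list(iterate_minibatches(num_examples, batch_size, rng))
print(len(batches))  # 550 optimizer steps per epoch

# Average the per-batch losses the same way the training loop does (dummy losses here)
batch_losses = np.linspace(1.0, 0.5, len(batches))
avg_cost = 0.0
for c in batch_losses:
    avg_cost += c / len(batches)
```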

Start the session:

# Start training
with tf.Session() as sess:
    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # Training loop
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)
        # Loop over all batches of the training set
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run the optimizer and fetch the loss value
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
                                                          y: batch_ys})
            # Accumulate the average loss over the epoch
            avg_cost += c / total_batch

        # Display training progress
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", '{:.9f}'.format(avg_cost))

    print("Training finished!")

    # Test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save the model weights
    save_path = saver.save(sess, model_path)
    print("Model file is at: %s" % save_path)

Output of the run:

Epoch: 0001 cost= 8.658398746
Epoch: 0002 cost= 4.599675331
Epoch: 0003 cost= 3.098299387
Epoch: 0004 cost= 2.414841038
Epoch: 0005 cost= 2.031551510
Epoch: 0006 cost= 1.787429208
Epoch: 0007 cost= 1.617599975
Epoch: 0008 cost= 1.491779541
Epoch: 0009 cost= 1.394358738
Epoch: 0010 cost= 1.316281419
Epoch: 0011 cost= 1.251967654
Epoch: 0012 cost= 1.197913221
Epoch: 0013 cost= 1.151722029
Epoch: 0014 cost= 1.111743248
Epoch: 0015 cost= 1.076424035
Epoch: 0016 cost= 1.045415161
Epoch: 0017 cost= 1.017401275
Epoch: 0018 cost= 0.992323116
Epoch: 0019 cost= 0.969426456
Epoch: 0020 cost= 0.948599738
Epoch: 0021 cost= 0.929346439
Epoch: 0022 cost= 0.911827402
Epoch: 0023 cost= 0.895336545
Epoch: 0024 cost= 0.880129020
Epoch: 0025 cost= 0.865876571

Finally, the model is saved:
Model file is at: mnist/521model.ckpt

4. Read the model and test it

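# Read the model
print("Starting the second session")
with tf.Session() as sess2:
    # Initialize variables (restore() then overwrites them with the saved values)
    sess2.run(tf.global_variables_initializer())
    # Load the weights from the saved model
    saver.restore(sess2, model_path)

    # Test the model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Predict two images taken from the training set
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval, predv = sess2.run([output, pred], feed_dict={x: batch_xs})
    # Printing `pred` shows the symbolic tensor (see the output below);
    # print `predv` instead to see the predicted probabilities
    print(outputval, pred, batch_ys)

    # Show the first image
    im = batch_xs[0]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

    # Show the second image
    im = batch_xs[1]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()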

Output of the run:

Accuracy: 0.8296
[0 2] Tensor("Softmax:0", shape=(?, 10), dtype=float32) [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]

The two randomly drawn images (taken from the training set, as in the code) are predicted as 0 and 2, which matches their labels. The model's overall accuracy on the test set is 0.8296, or about 83%.
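The accuracy node compares the index of the largest predicted probability with the index of the 1 in each one-hot label, then averages the matches. Here is a NumPy equivalent on made-up predictions (invented values, not the model's actual outputs):

```python
import numpy as np

# Made-up softmax outputs for 4 images (3 classes to keep it short)
pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.8, 0.1],
                 [0.3, 0.3, 0.4],
                 [0.6, 0.3, 0.1]])
# One-hot labels for the same images
y = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 1, 0],
              [0, 0, 1]])

# Equivalent of tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
correct_prediction = np.argmax(pred, axis=1) == np.argmax(y, axis=1)
# Equivalent of tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
accuracy = correct_prediction.astype(np.float32).mean()
print(accuracy)  # 0.5
```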

This completes handwritten digit recognition on the MNIST data set.

 



Origin blog.csdn.net/weixin_39121325/article/details/90297640