TensorFlow 2.0 in practice: handwritten digit recognition with a fully connected neural network

1. Import the libraries

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np

2. Load the data set

TensorFlow makes loading the data set convenient. It can be downloaded online or loaded from a local file:

# The next line is the online loading method (downloads MNIST on first use)
# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# The next two lines load a local copy of the data set
datapath = r'E:\Pycharm\project\project_TF\.idea\data\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(datapath)

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)   # normalization (L2 norm along axis 1)
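
As far as I understand it, tf.keras.utils.normalize with its default order divides by the L2 norm along the chosen axis, which is not the same as dividing the pixel values by 255. The NumPy sketch below illustrates the difference on made-up data. The helper l2_normalize and the array fake_images are hypothetical names used only for illustration, not part of the original program.

```python
import numpy as np

def l2_normalize(x, axis=1):
    """Hypothetical NumPy stand-in for L2 normalization along one axis."""
    x = np.asarray(x, dtype=np.float64)
    norms = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero vectors
    return x / norms

# Made-up stand-in for a few 28x28 grayscale images (values 0..255)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)

l2_scaled = l2_normalize(fake_images, axis=1)           # unit L2 norm along axis 1
minmax_scaled = fake_images.astype(np.float32) / 255.0  # common alternative: scale to [0, 1]

print(np.linalg.norm(l2_scaled, axis=1)[0, :3])  # norms along axis 1 are 1
print(minmax_scaled.min(), minmax_scaled.max())
```

Either scaling keeps the inputs small enough for training; dividing by 255 is the more common choice for image data.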

3. Establish a fully connected network model

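model = tf.keras.Sequential([ # a model nesting 3 non-linear layers
    tf.keras.layers.Flatten(),  # flatten each 28x28 image into a 784-element vector
    tf.keras.layers.Dense(784, activation='relu'),  # 128 units also works
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')  # softmax classification over the 10 digits
])
# Print the model
model.build((None, 28, 28))  # build first to tell the model the input shape (batch, height, width)
model.summary()  # summary() prints the table itself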
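
To make the computation inside this model concrete, here is a minimal NumPy sketch of the same forward pass: flatten, two ReLU layers, then softmax. The weights are random and untrained, and the names relu, softmax, forward, weights and biases are assumptions for illustration; this is not how Keras implements the layers internally.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
sizes = [784, 784, 128, 10]  # input size followed by the three Dense layer widths
weights = [rng.normal(0.0, 0.05, size=(a, b)) for a, b in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(b) for b in sizes[1:]]

def forward(images):
    h = images.reshape(len(images), -1)          # Flatten: (batch, 28, 28) -> (batch, 784)
    for W, b in zip(weights[:-1], biases[:-1]):
        h = relu(h @ W + b)                      # hidden Dense layers with ReLU
    return softmax(h @ weights[-1] + biases[-1]) # output layer: 10 class probabilities

batch = rng.random((5, 28, 28))  # made-up images
probs = forward(batch)
print(probs.shape)        # (5, 10)
print(probs.sum(axis=1))  # each row sums to 1
```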

The printed network structure is as follows:
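
The parameter counts reported by the summary can be checked by hand: a Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases, and Flatten has no parameters. The short calculation below is a sketch of that arithmetic; dense_params is a hypothetical helper name.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

layer_sizes = [784, 784, 128, 10]  # flattened input, then the three Dense layers
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
total = sum(counts)

print(counts)  # [615440, 100480, 1290]
print(total)   # 717210
```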

4. Model compilation and training

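model.compile(optimizer='adam',  # optimizer
              loss='sparse_categorical_crossentropy',  # cross-entropy loss for integer labels
              metrics=['accuracy'])  # metric reported during training and evaluation
# Train the model
model.fit(x_train, y_train, epochs=10, verbose=1)  # verbose=1 shows training progress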
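
The loss named above takes integer labels (0 to 9) directly, without one-hot encoding. For each sample it is the negative log of the probability the model assigns to the true class. The NumPy sketch below shows that calculation on made-up probabilities; the function name mirrors the Keras loss for readability, and the clipping constant eps is an assumption rather than the library's exact behavior.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Per-sample loss: -log(probability assigned to the true class)."""
    y_prob = np.clip(y_prob, eps, 1.0)                 # avoid log(0)
    picked = y_prob[np.arange(len(y_true)), y_true]    # probability of each true class
    return -np.log(picked)

# Made-up softmax outputs for two samples and three classes
probs = np.array([[0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
labels = np.array([1, 2])  # integer class labels

losses = sparse_categorical_crossentropy(labels, probs)
mean_loss = losses.mean()  # the reported loss is an average over samples
print(losses, mean_loss)
```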

5. Model testing

There are two ways to compute the test accuracy. Comparing them shows the difference between model.evaluate() and model.predict().
The first way (recommended):

# Test the model
val_loss, val_acc = model.evaluate(x_test, y_test)  # model.evaluate returns the computed loss and accuracy
print('First test Loss:{:.6f}'.format(val_loss))
print('First test accuracy:{:.6f}'.format(val_acc))

The second way (to deepen understanding):

# Test the model, second way
acc_correct = 0
predictions = model.predict(x_test)     # model.predict returns the predicted class probabilities
for i in range(len(x_test)):
    if np.argmax(predictions[i]) == y_test[i]:    # argmax returns the index of the largest value, here the most likely class
        acc_correct += 1
print('Second test accuracy:{:.6f}'.format(acc_correct / len(x_test)))
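
The counting loop above can also be written without an explicit loop by taking argmax over the class axis for every row at once. The sketch below uses small made-up arrays in place of predictions and y_test; accuracy_from_probs is a hypothetical helper name.

```python
import numpy as np

def accuracy_from_probs(probs, labels):
    """Fraction of rows whose most probable class equals the label."""
    predicted = np.argmax(probs, axis=1)  # one predicted class per row
    return float(np.mean(predicted == labels))

# Made-up stand-ins for model.predict(x_test) and y_test
demo_probs = np.array([[0.9, 0.1, 0.0],
                       [0.2, 0.7, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.6, 0.3, 0.1]])
demo_labels = np.array([0, 1, 2, 1])

demo_acc = accuracy_from_probs(demo_probs, demo_labels)
print(demo_acc)  # 0.75 (three of four rows match)
```

With the real arrays this would be called as accuracy_from_probs(predictions, y_test).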

At this point, the fully connected neural network program is complete.

6. The results of the program

The accuracy on the training set reached 0.9967, and the accuracy on the test set reached 0.9787. A convolutional neural network can increase the accuracy further.

7. A closer look at how np.argmax() works

I did not fully understand how argmax() works, so I ran the following test, which made it clear.

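i = 0  # first image in the test set
plt.imshow(x_test[i], cmap=plt.cm.binary)
plt.show()
print(np.argmax(predictions[i]))    # argmax returns the index of the largest value; predictions[i] holds the scores for the ten classes
print(predictions[i])               # e.g. the largest score in predictions[0] is the eighth value, index 7, so the predicted digit is 7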

The output is as follows:

plt.show() displays the first picture in the test set (a handwritten 7); the first print outputs 7; the second print outputs the values of predictions[0].
Looking at these 10 values, the eighth is the largest, close to 1. argmax() returns the index of the largest value, and since indexing starts at 0, the eighth value has index 7. The first print therefore outputs 7, meaning the model predicts that the handwritten digit in the picture is 7.
This makes the role and usage of argmax clear.
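
The same behavior can be reproduced on a small made-up vector without running the model. The sketch below also shows two details that are easy to miss: the effect of the axis argument on a 2-D array, and that ties return the first matching index. All array names here are made up for illustration.

```python
import numpy as np

# Ten made-up class scores; the eighth entry (index 7) is the largest
scores = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0])
idx = int(np.argmax(scores))
print(idx)  # 7

# On a 2-D array, axis=1 gives one index per row
batch = np.array([[0.1, 0.9],
                  [0.8, 0.2]])
per_row = np.argmax(batch, axis=1)
print(per_row)  # [1 0]

# Without axis, the array is flattened first
flat = int(np.argmax(batch))
print(flat)  # 1

# With ties, the first occurrence wins
tie = int(np.argmax(np.array([0.5, 0.5])))
print(tie)  # 0
```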


Origin blog.csdn.net/weixin_45371989/article/details/104581865