TensorFlow2.0 全连接神经网络 手写数字识别实战

1. 配置库文件

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np

2. 加载数据集

tensorflow加载数据集比较便捷,下面是在线加载和本地加载方式:

# 下面一行是在线加载方式
# mnist = tf.keras.datasets.mnist
# 下面两行是加载本地的数据集
datapath  = r'E:\Pycharm\project\project_TF\.idea\data\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(datapath)

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)   #归一化

3. 建立全连接网络模型

model = tf.keras.Sequential([ # 3 个非线性层的嵌套模型
    tf.keras.layers.Flatten(),  #将多维数据打平
    tf.keras.layers.Dense(784, activation='relu'),	# 128也行
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')	# softmax分类
])
# 打印模型
model.build((None,784,1))	# 这里要先build,告诉模型数据输入格式
print(model.summary())

打印出的网络结构如下:

4.模型编译与训练

model.compile(optimizer='adam',	# 优化器
              loss='sparse_categorical_crossentropy',	# 交叉熵损失函数
              metrics=['accuracy'])	# 标签
# 训练模型
model.fit(x_train, y_train, epochs=10,verbose=1) # verbose为1表示显示训练过程

5. 模型测试

这里提供两种方式给出精确度。可以据此了解model.evaluate()与model.predict()的区别,加深了解。
第一种(推荐):

#这里是测试模型
val_loss, val_acc = model.evaluate(x_test, y_test) # model.evaluate是输出计算的损失和精确度
print('First test Loss:{:.6f}'.format(val_loss)

第二种(加深理解):

#测试模型方式二
acc_correct = 0
predictions = model.predict(x_test)     # model.perdict是输出预测结果
for i in range(len(x_test)):
    if (np.argmax(predictions[i]) == y_test[i]):    # argmax是取最大数的索引,放这里是最可能的预测结果
        acc_correct += 1
print('Second test accuracy:{:.6f}'.format(acc_correct*1.0/(len(x_test))))

至此,完整的全连接神经网络程序完成。

6.程序运行结果

训练集中精确度达到了0.9967,测试集中精确度达到了0.9787。而采用卷积神经网络能够进一步增加精确度。

7.对np.argmax()工作的一点探索

由于对argmax()工作不是很理解,故做了以下测试,最终弄明白了。

i = 0	#测试集第一张图片
plt.imshow(x_test[i],cmap=plt.cm.binary)
plt.show()
print(np.argmax(predictions[i]))    # argmax输出的是最大数的索引,predicts[i]是十个分类的权值
print((predictions[i]))             # 比如predicts[0]最大的权值是第八个数,索引为7,故预测的数字为7

输出结果如下:

plt.show()输出的是测试集第一个图片(手写数字7);第一个print输出7;第二个print输出了Predictions[0]的值。
通过观察这10个权值可以看出,第八个权值最大,到了1;而argmax()是输出最大数的索引,权值最大的第八个数索引为7,故第一个print输出是7,即预测出了图片的手写数字是7。
因此,能够很清楚地了解argmax的作用及用法。

猜你喜欢

转载自blog.csdn.net/weixin_45371989/article/details/104581865