TensorFlow2.0 combate de reconocimiento de dígitos manuscritos de red neuronal completamente conectada

1. Configurar archivos de biblioteca

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np

2. Cargue el conjunto de datos

El conjunto de datos de carga de tensorflow es más conveniente, los siguientes son los métodos de carga en línea y local:

# 下面一行是在线加载方式
# mnist = tf.keras.datasets.mnist
# 下面两行是加载本地的数据集
datapath  = r'E:\Pycharm\project\project_TF\.idea\data\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(datapath)

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)   #归一化

3. Establecer un modelo de red completamente conectado

model = tf.keras.Sequential([ # 3 个非线性层的嵌套模型
    tf.keras.layers.Flatten(),  #将多维数据打平
    tf.keras.layers.Dense(784, activation='relu'),	# 128也行
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')	# softmax分类
])
# 打印模型
model.build((None,784,1))	# 这里要先build,告诉模型数据输入格式
print(model.summary())

La estructura de la red impresa es la siguiente:

4. Recopilación y capacitación de modelos

model.compile(optimizer='adam',	# 优化器
              loss='sparse_categorical_crossentropy',	# 交叉熵损失函数
              metrics=['accuracy'])	# 标签
# 训练模型
model.fit(x_train, y_train, epochs=10,verbose=1) # verbose为1表示显示训练过程

5. Prueba de modelos

Hay dos formas de dar precisión. Con base en esto, puede comprender la diferencia entre model.evaluate () y model.predict () y profundizar su comprensión.
El primero (recomendado):

#这里是测试模型
val_loss, val_acc = model.evaluate(x_test, y_test) # model.evaluate是输出计算的损失和精确度
print('First test Loss:{:.6f}'.format(val_loss)

El segundo tipo (para profundizar la comprensión):

#测试模型方式二
acc_correct = 0
predictions = model.predict(x_test)     # model.perdict是输出预测结果
for i in range(len(x_test)):
    if (np.argmax(predictions[i]) == y_test[i]):    # argmax是取最大数的索引,放这里是最可能的预测结果
        acc_correct += 1
print('Second test accuracy:{:.6f}'.format(acc_correct*1.0/(len(x_test))))

En este punto, se completa el programa de red neuronal completamente conectado.

6. Los resultados del programa

La precisión de la concentración de entrenamiento alcanzó 0,9967 y la precisión de la concentración de prueba alcanzó 0,9787. El uso de redes neuronales convolucionales puede aumentar aún más la precisión.

7. Una pequeña exploración del trabajo de np.argmax ()

Como no entiendo muy bien el trabajo de argmax (), hice la siguiente prueba y finalmente lo descubrí.

i = 0	#测试集第一张图片
plt.imshow(x_test[i],cmap=plt.cm.binary)
plt.show()
print(np.argmax(predictions[i]))    # argmax输出的是最大数的索引,predicts[i]是十个分类的权值
print((predictions[i]))             # 比如predicts[0]最大的权值是第八个数,索引为7,故预测的数字为7

El resultado es el siguiente:

plt.show () genera la primera imagen en el conjunto de prueba (número escrito a mano 7), la primera impresión genera 7, la segunda impresión genera el valor de Predicciones [0].
Al observar estos 10 pesos, se puede ver que el octavo peso es el más grande y llega a 1; y argmax () es el índice del mayor número de salidas, el índice del octavo número con el mayor peso es 7, por lo que la primera salida impresa Es 7, es decir, se predice que el número manuscrito de la imagen es 7.
Por lo tanto, podemos comprender claramente el papel y el uso de argmax.

Supongo que te gusta

Origin blog.csdn.net/weixin_45371989/article/details/104581865
Recomendado
Clasificación