Reconocimiento de dígitos manuscritos con Keras
Se puede decir que el reconocimiento de dígitos escritos a mano es el "hola mundo" en el campo del aprendizaje automático.
Es un buen caso práctico para empezar. El programa se ejecuta en Jupyter, por lo que se muestran muchos resultados intermedios. El libro de referencia es "Python Deep Learning".
import keras
from keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
Using TensorFlow backend.
# Load MNIST (downloaded on first use if it is not cached locally).
#   training images: (60000, 28, 28)   training labels: (60000,)
#   test images:     (10000, 28, 28)   test labels:     (10000,)
(train_img, train_label), (test_img, test_label) = mnist.load_data()

# Keep an untouched copy of the raw 28x28 test images so individual samples
# can still be displayed after test_img is flattened below.
test_predict = test_img.copy()

NUM_CLASSES = 10  # digits 0-9
VAL_SIZE = 30000  # the first VAL_SIZE training samples become the validation set

# Split the original training set into validation and training parts.
val_img, train_img = train_img[:VAL_SIZE], train_img[VAL_SIZE:]
val_label, train_label = train_label[:VAL_SIZE], train_label[VAL_SIZE:]


def _flatten_and_scale(images):
    """Flatten (N, 28, 28) images to (N, 784) float32 values in [0, 1].

    The batch dimension is inferred (-1) so the same helper works for every
    split regardless of its size. float32 is Keras' default float type and
    uses half the memory of float64.
    """
    return images.reshape((-1, 28 * 28)).astype('float32') / 255


train_img = _flatten_and_scale(train_img)
val_img = _flatten_and_scale(val_img)
test_img = _flatten_and_scale(test_img)

# One-hot encode the labels. Passing num_classes explicitly guarantees
# NUM_CLASSES columns even if a split happens to lack the highest digit.
train_label = keras.utils.to_categorical(train_label, num_classes=NUM_CLASSES)
val_label = keras.utils.to_categorical(val_label, num_classes=NUM_CLASSES)
test_label = keras.utils.to_categorical(test_label, num_classes=NUM_CLASSES)

# Show one handwritten digit from the raw test set (grayscale, as in MNIST).
plt.imshow(test_predict[0], cmap='gray')
# Two-layer fully connected network: a 512-unit ReLU hidden layer feeding
# a 10-way softmax classifier over the flattened 784-pixel input.
net = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(28 * 28,)),
    keras.layers.Dense(10, activation='softmax'),
])

# RMSprop optimizer with categorical cross-entropy (the labels are one-hot);
# accuracy is tracked alongside the loss.
net.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Train on the training split and score the validation split after every
# epoch; the returned History object is kept in `his` for plotting.
his = net.fit(
    x=train_img,
    y=train_label,
    epochs=10,
    batch_size=128,
    validation_data=(val_img, val_label),
)
Train on 30000 samples, validate on 30000 samples
Epoch 1/10
30000/30000 [==============================] - 5s 157us/step - loss: 0.3495 - acc: 0.8992 - val_loss: 0.2031 - val_acc: 0.9410
Epoch 2/10
30000/30000 [==============================] - 4s 131us/step - loss: 0.1550 - acc: 0.9543 - val_loss: 0.1517 - val_acc: 0.9552
Epoch 3/10
30000/30000 [==============================] - 3s 114us/step - loss: 0.1035 - acc: 0.9694 - val_loss: 0.1174 - val_acc: 0.9647
Epoch 4/10
30000/30000 [==============================] - 3s 113us/step - loss: 0.0744 - acc: 0.9774 - val_loss: 0.1060 - val_acc: 0.9682
Epoch 5/10
30000/30000 [==============================] - 3s 109us/step - loss: 0.0532 - acc: 0.9841 - val_loss: 0.0975 - val_acc: 0.9715
Epoch 6/10
30000/30000 [==============================] - 3s 103us/step - loss: 0.0404 - acc: 0.9878 - val_loss: 0.1001 - val_acc: 0.9709
Epoch 7/10
30000/30000 [==============================] - 4s 122us/step - loss: 0.0301 - acc: 0.9915 - val_loss: 0.0979 - val_acc: 0.9724
Epoch 8/10
30000/30000 [==============================] - 4s 130us/step - loss: 0.0223 - acc: 0.9939 - val_loss: 0.1012 - val_acc: 0.9725
Epoch 9/10
30000/30000 [==============================] - 3s 109us/step - loss: 0.0164 - acc: 0.9957 - val_loss: 0.0963 - val_acc: 0.9744
Epoch 10/10
30000/30000 [==============================] - 4s 118us/step - loss: 0.0125 - acc: 0.9971 - val_loss: 0.0986 - val_acc: 0.9737
# Pull the per-epoch metrics out of the training History.
his_dict = his.history
# Older Keras releases (like the one used here) record accuracy as
# 'acc'/'val_acc'; Keras >= 2.3 and tf.keras use 'accuracy'/'val_accuracy'.
# Pick whichever is present so this cell works with either.
acc_key = 'acc' if 'acc' in his_dict else 'accuracy'
loss = his_dict['loss']
val_loss = his_dict['val_loss']
acc = his_dict[acc_key]
val_acc = his_dict['val_' + acc_key]
epoch = range(1, len(loss) + 1)

# Plot loss and accuracy for both splits on shared axes; every curve gets
# a distinct legend label so the two accuracy lines can be told apart.
plt.plot(epoch, loss, 'k', label='train_loss')
plt.plot(epoch, val_loss, 'b', label='validation_loss')
plt.plot(epoch, acc, 'r', label='train_accuracy')
plt.plot(epoch, val_acc, 'g', label='validation_accuracy')
plt.title('train and valid')
plt.xlabel('epoch')
plt.ylabel('loss/acc')
plt.legend()
plt.show()
# Score the trained model on the test set; since it was compiled with
# metrics=['accuracy'], evaluate() yields the loss followed by the accuracy.
test_loss, test_acu = net.evaluate(x=test_img, y=test_label)
10000/10000 [==============================] - 1s 56us/step
# Report the test metrics; the embedded newline puts accuracy on its own line.
print('test_loss=', test_loss,
      '\ntest_accuracy=', test_acu,
      sep=' ')
test_loss= 0.0941627221799965
test_accuracy= 0.9742
# Predict class probabilities for the whole test set, then inspect one sample.
sample_idx = 11  # index into the test set; change it to inspect another digit
prediction = net.predict(test_img)
# Show the raw 28x28 image with its ground-truth label (test_label is one-hot,
# so argmax recovers the digit) to make the prediction easy to check.
plt.imshow(test_predict[sample_idx], cmap='gray')
plt.title('label: %d' % np.argmax(test_label[sample_idx]))
plt.show()
# The predicted digit is the class with the highest softmax probability.
print('result:', np.argmax(prediction[sample_idx]))
result: 6
La predicción (6) coincide con el dígito mostrado en la imagen, así que el resultado es correcto.