LeNet-5实战minist——搭建卷积网络模型

leNet-5模型

leNet-5是一个非常成功的神经网络模型。
基于LeNet-5的手写数字识别系统在90年代被美国很多银行使用,用来识别支票上面的手写数字。
LeNet-5共有7层。
在这里插入图片描述

mnist数据集简介

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
从tensorflow官网: link.可以查看到该数据集的调用命令。除了此数据集,tensorflow还提供了一些别的数据集,需要时可以自己上官网查看数据集的下载命令,此外也可在官网查询某些模型的使用方法。对于mnist数据集来说,它存放在tf.keras->datasets->mnist->load_data中:
调用指令

本文采用tensorflow-GPU版本进行模型搭建与训练

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
tf.__version__

‘2.2.0’

引入数据

(x_train, y_train), (x_test, y_test)=tf.keras.datasets.mnist.load_data()
def preprocess(x,y):
    x=tf.cast(x,dtype=tf.float32)/255.
    x=tf.reshape(x,[-1,28,28,1])
    y=tf.one_hot(y,depth=10)
    return x,y
(x_train, y_train), (x_test, y_test)=tf.keras.datasets.mnist.load_data() 
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_db=train_db.shuffle(10000)
train_db=train_db.batch(128)
train_db=train_db.map(preprocess)
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db=test_db.shuffle(10000)
test_db=test_db.batch(128)
test_db=test_db.map(preprocess)

建立模型

batch=32

model=tf.keras.Sequential([\
                           keras.layers.Conv2D(6,3),
                           keras.layers.MaxPooling2D(pool_size=2,strides=2),
                           keras.layers.ReLU(),
                           keras.layers.Conv2D(16,3),
                           keras.layers.MaxPooling2D(pool_size=2,strides=2),
                           keras.layers.ReLU(),
                           keras.layers.Flatten(),
                           keras.layers.Dense(120,activation='relu'),
                           keras.layers.Dense(84,activation='relu'),
                           keras.layers.Dense(10,activation='softmax')
                          ])
model.build(input_shape=(batch,28,28,1))
model.summary()

在这里插入图片描述

模型训练

model.compile(optimizer=keras.optimizers.Adam(),
             loss=keras.losses.CategoricalCrossentropy(),
             metrics=['accuracy'])
history=model.fit(train_db,epochs=50)

在这里插入图片描述

model.evaluate(test_db)

在这里插入图片描述

fig=plt.figure()
plt.plot(history.history['loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epochs')
plt.show()

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_48994268/article/details/109622807
今日推荐