TensorFlow实战框架Chp10--利用TFLearn在MNIST数据集上实现LeNet-5模型

  • 利用TFLearn在MNIST数据集上实现LeNet-5模型
# -*- coding: utf-8 -*-
"""
Created on Mon Jul  9 22:37:34 2018

@author: muli
"""

import tflearn
from tflearn.layers.core import input_data, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression

import tflearn.datasets.mnist as mnist


trainX, trainY, testX, testY = mnist.load_data(
    data_dir="./datasets/MNIST_data", one_hot=True)
# 将图像数据resize成卷积卷积神经网络输入的格式
# -1:由函数自动计算,表示batch的大小
trainX = trainX.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])

# 构建神经网络。
net = input_data(shape=[None, 28, 28, 1], name='input')
net = conv_2d(net, 32, 5, activation='relu')
net = max_pool_2d(net, 2)
net = conv_2d(net, 64, 5, activation='relu')
net = max_pool_2d(net, 2)
net = fully_connected(net, 500, activation='relu')
net = fully_connected(net, 10, activation='softmax')
# 定义学习任务。指定优化器为sgd,学习率为0.01,损失函数为交叉熵。
net = regression(net, optimizer='sgd', learning_rate=0.01,
                 loss='categorical_crossentropy')


# 通过定义的网络结构训练模型,并在指定的验证数据上验证模型的效果。
model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(trainX, trainY, n_epoch=3,
          validation_set=([testX, testY]),
          show_metric=True)


猜你喜欢

转载自blog.csdn.net/mr_muli/article/details/80982508