Training a fully connected neural network on MNIST handwritten digits with Python 3.6 and TensorFlow

1. Overview

I have only recently started learning TensorFlow, so I am taking it one step at a time. Here is the MNIST code I just finished, posted for future reference.

The project is a fully connected neural network with two hidden layers (extensible, already wrapped in a helper function): 784 -> 256 -> 49 -> 10. Training can be resumed after an interruption.

2. Test Environment

Python version: 3.6

TensorFlow version: 1.6.0

System environment: Windows 10 64-bit + Visual Studio Code

3. Resources

MNIST official site: http://yann.lecun.com/exdb/mnist/

Project code (includes the MNIST training set): https://pan.baidu.com/s/1ESuQ_JcXxbqIyTUF-W61Cw  extraction code: ylcj

4. Source Code

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Create the directory where model checkpoints are saved
ckpt_dir = "./ckpt_dir/"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Hyperparameters
learning_rate = 0.01
train_epoch = 100
batch_size = 50
hide_net1 = 256
hide_net2 = 49
#hide_net3 = 32
train_batch = int(mnist.train.num_examples / batch_size)  # number of mini-batches per epoch

x = tf.placeholder(tf.float32, [None, 784], name="X")
y = tf.placeholder(tf.float32, [None, 10], name="Y")

# # First hidden layer
# w1 = tf.Variable(tf.random_normal([784, hide_net1], stddev=1.0), name="w1")
# b1 = tf.Variable(tf.zeros(hide_net1), name="b1")

# y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

# # Second hidden layer
# w2 = tf.Variable(tf.random_normal([hide_net1, hide_net2], stddev=1.0), name="w2")
# b2 = tf.Variable(tf.zeros(hide_net2), name="b2")

# y2 = tf.nn.relu(tf.matmul(y1, w2) + b2)

# # # Third hidden layer
# # w3 = tf.Variable(tf.random_normal([hide_net2, hide_net3], stddev=1.0),name="w3")
# # b3 = tf.Variable(tf.zeros(hide_net3), name="b3")

# # y3 = tf.nn.relu(tf.matmul(y2, w3) + b3)


# w = tf.Variable(tf.random_normal([hide_net2, 10], stddev=1.0), name="w")
# b = tf.Variable(tf.zeros(10), name="b")

# forward = tf.matmul(y2, w) + b
# pred = tf.nn.softmax(forward)

# Generic fully connected layer: normally initialized weights, zero biases, optional activation
def fcn_layer(inputs, input_shape, output_shape, activate_function=None):
    ws = tf.Variable(tf.random_normal([input_shape, output_shape], stddev=1.0))
    bs = tf.Variable(tf.zeros(output_shape))

    if activate_function:
        return activate_function(tf.matmul(inputs, ws) + bs)
    else:
        return tf.matmul(inputs, ws) + bs

y1 = fcn_layer(x, 784, hide_net1, tf.nn.relu)
y2 = fcn_layer(y1, hide_net1, hide_net2, tf.nn.relu)
forward = fcn_layer(y2, hide_net2, 10, None)
pred = tf.nn.softmax(forward)
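# With the hyperparameters above, the tensor shapes are:
#   x: [None, 784] -> y1: [None, 256] -> y2: [None, 49] -> forward/pred: [None, 10]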

# Create the Saver used to write and restore checkpoints
saver = tf.train.Saver()

#loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))  # not used: log(pred) can hit log(0) and produce NaN
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

# Fraction of samples whose predicted class (argmax of pred) matches the true label
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), tf.argmax(y, axis=1)), tf.float32))

accuracy_list = []
loss_list = []

with tf.Session() as sess:
    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    epoch_last = 0

    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Restore the model from ", ckpt.model_checkpoint_path)
        epoch_last = int(ckpt.model_checkpoint_path[-11:-5])

        list_file = open("./ckpt_dir/list.txt", "r")
        accuracy_list_str, loss_list_str = list_file.readline().split("@")
        accuracy_list = json.loads(accuracy_list_str)
        loss_list = json.loads(loss_list_str)
        list_file.close()

    if epoch_last < train_epoch:
        for epoch in range(epoch_last, train_epoch):
            for batch in range(train_batch):
                xs, ys = mnist.train.next_batch(batch_size)
                sess.run(optimizer, feed_dict={x: xs, y: ys})

            loss_temp, accuracy_temp = sess.run([loss_function, accuracy],
                                                feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
            print("Epoch:", epoch + 1, " loss=", loss_temp, " accuracy=", accuracy_temp)

            accuracy_list.append(float(accuracy_temp))
            loss_list.append(float(loss_temp))

            if (epoch + 1) % 5 == 0:
                saver.save(sess, os.path.join(ckpt_dir, "mnist_h256_model_{:06d}.ckpt".format(epoch + 1)))
                print("mnist_h256_model_{:06d}.ckpt is saved".format(epoch + 1))
                # Persist the accuracy/loss history as two JSON lists joined by "@"
                list_file = open("./ckpt_dir/list.txt", "w")
                list_file.write(json.dumps(accuracy_list) + "@" + json.dumps(loss_list))
                list_file.close()

        if train_epoch % 5 != 0:
            saver.save(sess, os.path.join(ckpt_dir, "mnist_h256_model_{:06d}.ckpt".format(train_epoch)))
        print("Model saved!")

plt.rcParams['font.sans-serif'] = ['FangSong']  # default font (FangSong can render CJK labels)
plt.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box when saving figures

loss_img = plt.subplot(1, 2, 1)
plt.plot(loss_list)
loss_img.set_title("Loss curve")

accuracy_img = plt.subplot(1, 2, 2)
plt.plot(accuracy_list)
accuracy_img.set_title("Accuracy curve")

plt.show()
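
The script above only reports accuracy on the validation set while training. A minimal sketch of a final test-set check, assuming it is appended to the end of this same script (so that x, y, accuracy, saver, ckpt_dir and mnist are still in scope), could look like this:

# Sketch: restore the most recent checkpoint and evaluate once on the held-out test set
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print("Test accuracy:", test_accuracy)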
