一、说明
最近刚开始学习 TensorFlow,一步一步来,先把刚写完的 MNIST 代码贴上来,以备日后查阅。
本项目基于全连接神经网络,包含两个隐藏层(层的构建已封装为函数 fcn_layer,可方便扩展),网络结构为 784->256->49->10,并支持训练中断后从最近保存的检查点继续训练。
二、测试环境
Python 版本:3.6
TensorFlow 版本:1.6.0
系统环境:Windows 10 64位 + Visual Studio Code
三、资料
MNIST 官网:http://yann.lecun.com/exdb/mnist/
项目代码(内含mnist训练集):链接:https://pan.baidu.com/s/1ESuQ_JcXxbqIyTUF-W61Cw 提取码:ylcj
四、源代码
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Directory that holds the model checkpoints and the saved training curves.
ckpt_dir = "./ckpt_dir/"
# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory were created between the check and the call.
os.makedirs(ckpt_dir, exist_ok=True)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Hyperparameters
learning_rate = 0.01
train_epoch = 100  # total epochs; a resumed run continues up to this number
batch_size = 50
hide_net1 = 256  # units in the first hidden layer
hide_net2 = 49   # units in the second hidden layer
# hide_net3 = 32
# Number of mini-batches per epoch (integer division, no float round-trip).
train_batch = mnist.train.num_examples // batch_size

# Inputs: 784 features per image, and 10-way one-hot labels (one_hot=True above).
x = tf.placeholder(tf.float32, [None, 784], name="X")
y = tf.placeholder(tf.float32, [None, 10], name="Y")
# # 第一层隐藏层
# w1 = tf.Variable(tf.random_normal([784, hide_net1], stddev=1.0), name="w1")
# b1 = tf.Variable(tf.zeros(hide_net1), name="b1")
# y1 = tf.nn.relu(tf.matmul(x, w1) + b1)
# # 第二隐藏层
# w2 = tf.Variable(tf.random_normal([hide_net1, hide_net2], stddev=1.0), name="w2")
# b2 = tf.Variable(tf.zeros(hide_net2), name="b2")
# y2 = tf.nn.relu(tf.matmul(y1, w2) + b2)
# # # 第三隐藏层
# # w3 = tf.Variable(tf.random_normal([hide_net2, hide_net3], stddev=1.0),name="w3")
# # b3 = tf.Variable(tf.zeros(hide_net3), name="b3")
# # y3 = tf.nn.relu(tf.matmul(y2, w3) + b3)
# w = tf.Variable(tf.random_normal([hide_net2, 10], stddev=1.0), name="w")
# b = tf.Variable(tf.zeros(10), name="b")
# forward = tf.matmul(y2, w) + b
# pred = tf.nn.softmax(forward)
def fcn_layer(inputs, input_shape, output_shape, activate_function=None, stddev=1.0):
    """Build one fully connected layer: ``activate_function(inputs @ W + b)``.

    Args:
        inputs: 2-D float tensor fed into the layer; its second dimension is
            expected to equal ``input_shape``.
        input_shape: number of input features (rows of the weight matrix W).
        output_shape: number of output units (columns of W, length of b).
        activate_function: optional callable applied to the affine output,
            e.g. ``tf.nn.relu``. When it is None/falsy the raw affine output is
            returned, which is how the logits layer is built.
        stddev: standard deviation of the normal initializer for W. The default
            of 1.0 keeps the original behaviour; smaller values (for example
            on the order of 1/sqrt(input_shape)) keep pre-activations from
            becoming very large for wide inputs.

    Returns:
        The layer's output tensor, with ``output_shape`` columns.
    """
    # The variables are intentionally left unnamed. TensorFlow then names them
    # "Variable", "Variable_1", ... in creation order, and tf.train.Saver keys
    # checkpoints by those names. Naming them here (or changing the creation
    # order) would stop previously saved checkpoints from restoring.
    ws = tf.Variable(tf.random_normal([input_shape, output_shape], stddev=stddev))
    bs = tf.Variable(tf.zeros(output_shape))
    affine = tf.matmul(inputs, ws) + bs
    return activate_function(affine) if activate_function else affine
# Network: 784 -> hide_net1 -> hide_net2 -> 10, ReLU on both hidden layers.
y1 = fcn_layer(x, 784, hide_net1, tf.nn.relu)
y2 = fcn_layer(y1, hide_net1, hide_net2, tf.nn.relu)
forward = fcn_layer(y2, hide_net2, 10, None)  # unscaled logits
pred = tf.nn.softmax(forward)  # class probabilities, used for accuracy
# The loss is computed from the logits instead of -sum(y * log(pred)):
# log(pred) can evaluate log(0) and yield inf/NaN, whereas
# softmax_cross_entropy_with_logits computes the same quantity stably.
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
correct_prediction = tf.equal(tf.argmax(pred, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=0)
# tf.train.Saver() captures the global variables that exist when it is
# constructed. It is created after minimize() so that variables added by an
# optimizer (e.g. Adam/Momentum slots) are also checkpointed if the optimizer
# is ever swapped. Plain gradient descent adds none, so current checkpoints
# are unaffected.
saver = tf.train.Saver()
# Per-epoch validation metrics; replaced by the saved history when resuming.
accuracy_list = []
loss_list = []
# Training curves are stored next to the checkpoints so a resumed run keeps
# its full history for plotting.
history_path = os.path.join(ckpt_dir, "list.txt")


def _save_history(path, accuracies, losses):
    """Write the metric lists to *path* as one line: ``<accuracy json>@<loss json>``.

    Values are converted to plain floats first, because numpy scalars are not
    JSON serializable. json.dumps (unlike str()) writes non-finite values as
    NaN/Infinity, which json.loads can read back.
    """
    accuracy_json = json.dumps([float(v) for v in accuracies])
    loss_json = json.dumps([float(v) for v in losses])
    with open(path, "w") as history_file:
        history_file.write(accuracy_json + "@" + loss_json)


def _load_history(path):
    """Return ``(accuracy_list, loss_list)`` as written by ``_save_history``.

    If *path* does not exist, print a warning and return two empty lists,
    instead of aborting the resumed run.
    """
    if not os.path.exists(path):
        print("Warning: training history not found at", path)
        return [], []
    with open(path, "r") as history_file:
        accuracy_json, loss_json = history_file.readline().split("@")
    return json.loads(accuracy_json), json.loads(loss_json)


def _epoch_from_checkpoint(path):
    """Parse the epoch from a checkpoint path.

    Example: ``./ckpt_dir/mnist_h256_model_000005.ckpt`` gives 5.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit("_", 1)[-1])


with tf.Session() as sess:
    # Start from freshly initialized weights; a restore below overwrites them.
    sess.run(tf.global_variables_initializer())
    epoch_last = 0
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Restore the model from ", ckpt.model_checkpoint_path)
        epoch_last = _epoch_from_checkpoint(ckpt.model_checkpoint_path)
        accuracy_list, loss_list = _load_history(history_path)
    if epoch_last < train_epoch:
        for epoch in range(epoch_last, train_epoch):
            for _ in range(train_batch):
                xs, ys = mnist.train.next_batch(batch_size)
                sess.run(optimizer, feed_dict={x: xs, y: ys})
            # Evaluate on the validation split once per epoch.
            loss_temp, accuracy_temp = sess.run(
                [loss_function, accuracy],
                feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
            print("训练轮数:", epoch + 1, " loss=", loss_temp, " accuracy=", accuracy_temp)
            accuracy_list.append(accuracy_temp)
            loss_list.append(loss_temp)
            if (epoch + 1) % 5 == 0:
                saver.save(sess, os.path.join(ckpt_dir, "mnist_h256_model_{:06d}.ckpt".format(epoch + 1)))
                print("mnist_h256_model_{:06d}.ckpt is saved".format(epoch + 1))
                _save_history(history_path, accuracy_list, loss_list)
        if train_epoch % 5 != 0:
            # The last epoch is not a multiple of 5, so it was not saved in the
            # loop. Save it explicitly and write the history as well. Without
            # the history, a resumed run would restore epoch train_epoch but
            # load curves that stop at the last multiple of 5.
            saver.save(sess, os.path.join(ckpt_dir, "mnist_h256_model_{:06d}.ckpt".format(train_epoch)))
            _save_history(history_path, accuracy_list, loss_list)
            print("Model saved!")
# Use a CJK-capable font so the Chinese titles render, and draw the minus sign
# as ASCII '-' so it does not appear as an empty box.
plt.rcParams.update({
    'font.sans-serif': ['FangSong'],
    'axes.unicode_minus': False,
})


def _draw_curve(position, values, title):
    """Plot *values* in panel *position* of a 1x2 grid and return that Axes."""
    panel = plt.subplot(1, 2, position)
    panel.plot(values)
    panel.set_title(title)
    return panel


loss_img = _draw_curve(1, loss_list, "损失变化曲线")
accuracy_img = _draw_curve(2, accuracy_list, "精确度变化曲线")
plt.show()