实现 MNIST 手写数字数据集的识别任务，代码共分为三个模块文件，分别是：
描述网络结构的前向传播过程文件(mnist_forward.py)；
描述网络参数优化方法的反向传播过程文件(mnist_backward.py)；
验证模型准确率的测试过程文件(mnist_test.py)。
描述网络结构的前向传播过程文件(mnist_forward.py)
#coding:utf-8
"""Forward propagation: defines the structure of the MNIST network.

The network is fully connected:
INPUT_NODE -> LAYER1_NODE (ReLU) -> OUTPUT_NODE.
No activation is applied to the output layer, so ``forward`` returns
unnormalised logits.
"""
import tensorflow as tf

INPUT_NODE = 784    # input features per example (e.g. a flattened 28x28 image)
OUTPUT_NODE = 10    # number of output classes / width of the label vectors
LAYER1_NODE = 500   # number of hidden units


def get_weight(shape, regularizer):
    """Create a weight variable initialised from a truncated normal (stddev 0.1).

    Args:
        shape: shape of the weight matrix.
        regularizer: L2 regularisation scale, or None to disable regularisation.
            When set, the L2 penalty of the new variable is added to the
            ``'losses'`` collection so the training code can add it to the loss.

    Returns:
        The new ``tf.Variable``.
    """
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        # tf.contrib only exists in TensorFlow 1.x.
        penalty = tf.contrib.layers.l2_regularizer(regularizer)(w)
        # l2_regularizer returns None for a scale of 0 ("disables regularizer");
        # keep such entries out of the collection so tf.add_n never sees None.
        if penalty is not None:
            tf.add_to_collection('losses', penalty)
    return w


def get_bias(shape):
    """Create a bias variable of the given shape, initialised to zeros."""
    return tf.Variable(tf.zeros(shape))


def forward(x, regularizer):
    """Build the two-layer network and return its output logits.

    Args:
        x: input tensor of shape (batch, INPUT_NODE).
        regularizer: L2 scale for the weights, or None to skip regularisation.

    Returns:
        Tensor of shape (batch, OUTPUT_NODE) holding raw (pre-softmax) scores.
    """
    # Hidden layer: affine transform followed by ReLU.
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    b1 = get_bias([LAYER1_NODE])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    # Output layer: affine transform only (no activation).
    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1, w2) + b2
    return y
描述网络参数优化方法的反向传播过程文件(mnist_backward.py)
#coding:utf-8
"""Back propagation: trains the MNIST network and saves checkpoints.

Minimises softmax cross-entropy plus the L2 penalties collected under
``'losses'`` with gradient descent on an exponentially decaying learning rate.
It also maintains a moving average of all trainable variables and writes
checkpoints to ``MODEL_SAVE_PATH``.
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_forward

BATCH_SIZE = 200          # examples per training step
REGULARIZER = 0.0001      # L2 scale passed to mnist_forward.forward
LR = 0.1                  # initial learning rate
LR_DECAY_RATE = 0.99      # learning-rate multiplier applied once per epoch
EMA_DECAY = 0.99          # decay of the moving average of trainable variables
STEPS = 50000             # total number of training steps
MODEL_SAVE_PATH = './model/'
MODEL_NAME = 'mnist_model'


def backward(mnist):
    """Train the network on ``mnist.train`` and checkpoint it to MODEL_SAVE_PATH.

    Args:
        mnist: dataset object from ``input_data.read_data_sets(..., one_hot=True)``.
            Only ``mnist.train.num_examples`` and ``mnist.train.next_batch`` are used.

    Checkpoints are written every 1000 steps and once more after the last step,
    named ``MODEL_NAME-<global_step>``.
    """
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    y = mnist_forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # Staircase decay: decay_steps is the number of batches per epoch, so the
    # rate drops by LR_DECAY_RATE once per pass over the training set.
    lr = tf.train.exponential_decay(
        learning_rate=LR,
        global_step=global_step,
        decay_steps=mnist.train.num_examples / BATCH_SIZE,
        decay_rate=LR_DECAY_RATE,
        staircase=True
    )

    # y_ is one-hot; the sparse loss expects class indices, hence the argmax.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y)
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    train_step = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(
        decay=EMA_DECAY,
        num_updates=global_step
    )
    ema_op = ema.apply(tf.trainable_variables())
    # Single op that runs both the gradient step and the moving-average update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op('train')

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
    # Saver.save fails when the parent directory does not exist, so create it.
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_v, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print('After %d training steps, loss on training batch is %g.' % (step, loss_v))
                saver.save(sess, checkpoint_prefix, global_step=global_step)
        # In-loop saves happen only every 1000 iterations; without this, up to
        # the last 999 training steps would never reach disk.
        saver.save(sess, checkpoint_prefix, global_step=global_step)


def main():
    """Load MNIST with one-hot labels from ./data/ and train the model."""
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
model 文件夹：mnist_backward.py 在训练过程中会定期把模型以 mnist_model-<训练轮数> 的文件名保存到该文件夹，mnist_test.py 从中读取最新保存的模型。
验证模型准确率的测试过程文件(mnist_test.py)
#coding:utf-8
"""Test script: periodically measure test-set accuracy of the latest checkpoint.

The network is rebuilt from ``mnist_forward``. The moving-average (shadow)
values of its variables are restored from ``mnist_backward.MODEL_SAVE_PATH``.
Accuracy on ``mnist.test`` is printed every ``TEST_INTERVAL_SECS`` seconds.
"""
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_forward
import mnist_backward

TEST_INTERVAL_SECS = 5  # pause between two evaluations, in seconds


def _global_step_from(checkpoint_path):
    """Return the text after the last '-' in the file-name part of a checkpoint path."""
    file_name = checkpoint_path.rsplit('/', 1)[-1]
    return file_name.rsplit('-', 1)[-1]


def test(mnist):
    """Evaluate the newest checkpoint on ``mnist.test`` in an endless loop.

    Each round restores the most recent checkpoint and prints its accuracy,
    then sleeps TEST_INTERVAL_SECS seconds. If no checkpoint is found, it
    prints a message and returns.
    """
    with tf.Graph().as_default():
        images = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        labels = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        logits = mnist_forward.forward(images, None)

        # Map moving-average names to the variables, so that restoring loads
        # the averaged (shadow) values instead of the raw trained weights.
        averager = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
        saver = tf.train.Saver(averager.variables_to_restore())

        # Fraction of examples whose predicted class matches the one-hot label.
        hits = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found.')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                trained_steps = _global_step_from(ckpt.model_checkpoint_path)
                feed = {images: mnist.test.images, labels: mnist.test.labels}
                score = sess.run(accuracy, feed_dict=feed)
                print('After %s training steps, test accuracy = %g' % (trained_steps, score))
            time.sleep(TEST_INTERVAL_SECS)


def main():
    """Load MNIST with one-hot labels from ./data/ and start the evaluation loop."""
    dataset = input_data.read_data_sets('./data/', one_hot=True)
    test(dataset)


if __name__ == '__main__':
    main()