[ MOOC Course Notes ] Artificial Intelligence Practice: TensorFlow Notes_CH5_2/3 Recognizing the Handwritten MNIST Dataset

The handwritten-digit (MNIST) recognition task is split across three module files:

the forward-propagation file that describes the network structure (mnist_forward.py)
the backward-propagation file that describes how the network parameters are optimized (mnist_backward.py)
the test file that evaluates the model's accuracy (mnist_test.py); a run-order sketch follows this list.
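
A minimal sketch of how the three files are used together (this only shows the run order; the training and test scripts are normally started in two separate terminals):

    # Train: builds the graph from mnist_forward, optimizes it, and writes a
    # checkpoint to ./model/ every 1000 steps.
    #   python mnist_backward.py
    #
    # Evaluate: every 5 seconds, reloads the newest checkpoint from ./model/
    # and prints accuracy on the 10,000-image MNIST test set.
    #   python mnist_test.py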

  1. Forward-propagation file describing the network structure (mnist_forward.py)

    
    #coding:utf-8
    
    import tensorflow as tf
    
    INPUT_NODE = 784       # 28x28 pixels flattened into a 784-dimensional vector
    OUTPUT_NODE = 10       # one logit per digit class 0-9
    LAYER1_NODE = 500      # width of the single hidden layer
    
    def get_weight(shape, regularizer):
        # Weights start from a truncated normal; if a regularization coefficient is
        # given, the L2 penalty on w is added to the 'losses' collection so the
        # training script can fold it into the total loss.
        w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
        if regularizer is not None:
            tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
        return w
    
    def get_bias(shape):
        b = tf.Variable(tf.zeros(shape))
        return b
    
    def forward(x, regularizer):
        # Hidden layer: 784 -> 500 with ReLU activation.
        w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
        b1 = get_bias([LAYER1_NODE])
        y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

        # Output layer: 500 -> 10 raw logits (softmax is applied inside the loss op).
        w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
        b2 = get_bias([OUTPUT_NODE])
        y = tf.matmul(y1, w2) + b2

        return y
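
The forward file only builds the graph; nothing runs until a session feeds it data. As a quick sanity check, the sketch below (not part of the course code) pushes a random batch through forward() with regularization disabled and prints the shape of the logits:

    #coding:utf-8
    import numpy as np
    import tensorflow as tf
    import mnist_forward

    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y = mnist_forward.forward(x, None)   # regularizer=None: skip the L2 terms

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.rand(8, mnist_forward.INPUT_NODE).astype(np.float32)
        print(sess.run(y, feed_dict={x: batch}).shape)   # (8, 10): one logit vector per image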
    
  2. Backward-propagation file describing how the network parameters are optimized (mnist_backward.py)

    
    #coding:utf-8
    
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import os
    import mnist_forward
    
    
    BATCH_SIZE = 200             # examples fed per training step
    REGULARIZER = 0.0001         # L2 regularization coefficient
    LR = 0.1                     # initial learning rate
    LR_DECAY_RATE = 0.99         # multiplicative learning-rate decay
    EMA_DECAY = 0.99             # decay for the exponential moving averages
    STEPS = 50000                # total number of training steps
    MODEL_SAVE_PATH = './model/'
    MODEL_NAME = 'mnist_model'
    
    def backward(mnist):
    
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        y = mnist_forward.forward(x, REGULARIZER)
    
        global_step = tf.Variable(0, trainable=False)
    
        lr = tf.train.exponential_decay(
            learning_rate = LR,
            global_step = global_step,
            decay_steps = mnist.train.num_examples / BATCH_SIZE,
            decay_rate = LR_DECAY_RATE,
            staircase = True
        )
    
        # Cross-entropy on the softmax of the logits; the one-hot labels are turned
        # back into class indices for the sparse version of the op.
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y)
        cem = tf.reduce_mean(ce)
        # Total loss = mean cross-entropy + all L2 terms collected in mnist_forward.get_weight.
        loss = cem + tf.add_n(tf.get_collection('losses'))
        train_step = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)

        # Maintain exponential moving averages of all trainable variables;
        # mnist_test.py later restores these shadow values instead of the raw weights.
        ema = tf.train.ExponentialMovingAverage(
            decay = EMA_DECAY,
            num_updates = global_step
        )
        ema_op = ema.apply(tf.trainable_variables())

        # Bundle the optimizer step and the EMA update into one training op.
        with tf.control_dependencies([train_step, ema_op]):
            train_op = tf.no_op(name='train')
    
        saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Saver.save fails if the target directory does not exist yet.
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            for i in range(STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                _, loss_v, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
                if i % 1000 == 0:
                    print('After %d training steps, loss on training batch is %g.' % (step, loss_v))
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
    
    def main():
        mnist = input_data.read_data_sets('./data/', one_hot=True)
        backward(mnist)
    
    if __name__ == '__main__':
        main()
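
For reference, the staircase schedule built above can be checked by hand: with the standard 55,000-image training split and BATCH_SIZE = 200, decay_steps is 275, so the learning rate is multiplied by 0.99 once per 275 global steps (roughly once per epoch). A small sketch, not part of the course code:

    # lr(step) = LR * LR_DECAY_RATE ** (step // decay_steps)   with staircase=True
    decay_steps = 55000 // 200                 # 275 steps = one pass over the training set
    for step in (0, 275, 1000, 10000, 50000):
        lr = 0.1 * 0.99 ** (step // decay_steps)
        print(step, lr)
    # Starts at 0.1, drops to 0.099 after the first epoch, and is roughly 0.016 by step 50000.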
    
    

    [Image: contents of the model/ folder after training, i.e. the checkpoint files written by Saver]
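
The saved checkpoints can also be listed programmatically; a small sketch (it assumes training has already written at least one checkpoint):

    import tensorflow as tf

    # Path of the most recent checkpoint, e.g. './model/mnist_model-49001'
    # (the numeric suffix is the global_step value at save time).
    print(tf.train.latest_checkpoint('./model/'))

    # The Saver keeps only the most recent checkpoints (5 by default); list them all.
    ckpt = tf.train.get_checkpoint_state('./model/')
    print(ckpt.all_model_checkpoint_paths)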

  3. Test file that evaluates the model's accuracy (mnist_test.py)

    
    #coding:utf-8
    
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import mnist_forward
    import mnist_backward
    TEST_INTERVAL_SECS = 5
    
    def test(mnist):
        with tf.Graph().as_default() as g:
            x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
            y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
            y = mnist_forward.forward(x, None)
    
            # Build a Saver that restores the moving-average (shadow) values:
            # when the checkpoint is loaded, every parameter is assigned its EMA value.
            ema = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
            ema_restore = ema.variables_to_restore()
            saver = tf.train.Saver(ema_restore)
    
            # Accuracy: fraction of images whose arg-max prediction matches the label.
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
            while True:
                with tf.Session() as sess:
                    # Look up the most recent checkpoint written by mnist_backward.py.
                    ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                    # Restore it if one exists.
                    if ckpt and ckpt.model_checkpoint_path:
                        # Load the (moving-average) variable values into this session.
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        # Recover the training step from the checkpoint file name, e.g. mnist_model-49001.
                        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                        # Evaluate accuracy on the full test set.
                        accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                        print('After %s training steps, test accuracy = %g' % (global_step, accuracy_score))
                    else:
                        print('No checkpoint file found.')
                        return
                time.sleep(TEST_INTERVAL_SECS)
    
    def main():
        mnist = input_data.read_data_sets('./data/', one_hot=True)
        test(mnist)
    
    if __name__ == '__main__':
        main()
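
The same restore logic can be reused to classify a single image instead of scoring the whole test set. A minimal sketch (it assumes a checkpoint already exists under ./model/):

    #coding:utf-8
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import mnist_forward
    import mnist_backward

    mnist = input_data.read_data_sets('./data/', one_hot=True)

    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y = mnist_forward.forward(x, None)
        pred = tf.argmax(y, 1)

        # Restore the moving-average values, exactly as in mnist_test.py.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            saver.restore(sess, ckpt.model_checkpoint_path)
            img = mnist.test.images[:1]                       # one 784-dim test image
            print('predicted digit:', sess.run(pred, feed_dict={x: img})[0])
            print('true label:', np.argmax(mnist.test.labels[0]))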

Reposted from blog.csdn.net/ranmw1129/article/details/81099377