TensorFlow Peking University Open Course Study Notes - 5.3 Outputting the Accuracy of Handwritten Digit Recognition

1. mnist_forward.py

#coding:utf-8
#0 Import modules
import tensorflow as tf

INPUT_NODE=784  #Input layer nodes (28*28 pixels per image)
OUTPUT_NODE=10  #Output layer nodes (one per digit class)
LAYER1_NODE=500 #Hidden layer nodes

#Define the network's inputs, parameters and outputs, and the forward propagation process
def get_weight(shape,regularizer):
    w=tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    if regularizer is not None:
        # Store this weight's L2 penalty in the 'losses' collection; it is summed in mnist_backward.py
        tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b=tf.Variable(tf.zeros(shape))
    return b

def forward(x,regularizer):
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    b1 = get_bias([LAYER1_NODE])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1, w2) + b2  # No activation on the output layer: y holds the raw logits
    return y
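
The L2 penalty built in get_weight is not applied on the spot; it is stored in the 'losses' collection and summed later in mnist_backward.py via tf.add_n(tf.get_collection('losses')). A minimal standalone sketch of this collection mechanism (the shape and coefficient below are illustrative, not taken from the course code):

#coding:utf-8
import tensorflow as tf

w = tf.Variable(tf.truncated_normal([3, 2], stddev=0.1))
# Stash this weight's L2 penalty in the shared 'losses' collection
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(0.0001)(w))
total_reg = tf.add_n(tf.get_collection('losses'))  # sum of every stored penalty

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(total_reg))  # a small positive scalar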

2. mnist_backward.py

#coding:utf-8
#0 Import modules
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE=200              #Number of examples fed in per training step
LEARNING_RATE_BASE=0.1      #Initial learning rate
LEARNING_RATE_DECAY=0.99    #Learning rate decay rate
REGULARIZER=0.0001          #Regularization coefficient
STEPS=50000                 #Total number of training steps
MOVING_AVERAGE_DECAY=0.99   #Moving average decay rate
MODEL_SAVE_PATH="./model"   #Model save path
MODEL_NAME="mnist_model"    #Model file name

def backward(mnist):
    x = tf.placeholder(tf.float32, [None,mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
    y=mnist_forward.forward(x,REGULARIZER)

    # Counter for how many batches have been run; starts at 0 and is marked not trainable
    global_step = tf.Variable(0, trainable=False)

    # Define the loss: cross entropy plus the L2 terms collected in mnist_forward
    ce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
    cem=tf.reduce_mean(ce)
    loss=cem+tf.add_n(tf.get_collection('losses'))

    # Define the exponentially decaying learning rate
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    # Define the training step (the loss above already includes the regularization terms)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)

    # Maintain an exponential moving average of every trainable variable and bundle
    # the average update with the training step into a single train_op
    ema=tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
    ema_op=ema.apply(tf.trainable_variables())
    with tf.control_dependencies([train_step,ema_op]):
        train_op=tf.no_op(name='train')

    saver=tf.train.Saver()
    # tf.train.Saver does not create the directory, so make sure it exists before saving
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(STEPS):
            xs,ys=mnist.train.next_batch(BATCH_SIZE)
            _,loss_value,step=sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
            if i % 1000 == 0:
                print("After %d training step(s),loss on training batch is %g."%(step,loss_value))
                saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)


def main():
    mnist=input_data.read_data_sets('./data/',one_hot=True)
    backward(mnist)

if __name__=='__main__':
    main()
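
With staircase=True, the schedule above works out to learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // decay_steps), where decay_steps = mnist.train.num_examples / BATCH_SIZE, so the learning rate drops once per epoch. A quick sketch that reproduces the schedule outside the graph (it assumes the usual 55,000-image MNIST training split):

#coding:utf-8
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
DECAY_STEPS = 55000 // 200   # mnist.train.num_examples / BATCH_SIZE under the standard split

def lr_at(global_step):
    # staircase=True -> the exponent is truncated to an integer
    return LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // DECAY_STEPS)

for step in (0, 275, 550, 5000):
    print(step, lr_at(step))  # 0.1, 0.099, 0.09801, ...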

3. The test program (accuracy evaluation)

#coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
TEST_INTERVAL_SECS=5    #Interval between evaluation runs, in seconds

def test(mnist):   #Takes the loaded MNIST dataset
    with tf.Graph().as_default() as g:  #Rebuild the computation graph in a fresh default graph
        #Placeholders for the images and labels
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        #Forward propagation to compute y (no regularization needed at test time)
        y = mnist_forward.forward(x, None)

        #Instantiate a saver that restores the moving averages: when the parameters are loaded, each one is assigned its own moving-average value
        ema=tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore=ema.variables_to_restore()
        saver=tf.train.Saver(ema_restore)

        #Compute accuracy: the fraction of examples whose predicted class matches the label
        correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

        while True:
            with tf.Session() as sess:
                #Load the checkpoint, i.e. assign the moving-average values to the parameters
                ckpt=tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                #Check whether a saved model exists; if so, restore it into the current session
                if ckpt and ckpt.model_checkpoint_path:
                    #Restore the model into the current session
                    saver.restore(sess,ckpt.model_checkpoint_path)
                    #Recover the global_step value from the checkpoint file name
                    global_step=ckpt.model_checkpoint_path.split('/')[-1].split("-")[-1]
                    #Evaluate accuracy on the test set
                    accuracy_score=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                else:
                    print("No checkpoint file found")  #未找到模型
                    return
            time.sleep(TEST_INTERVAL_SECS)

def main():
    mnist=input_data.read_data_sets('./data/',one_hot=True) #Load the MNIST dataset
    test(mnist)  #Run the evaluation loop

if __name__=='__main__':
    main()
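
The reason the test accuracy reflects the smoothed weights is ema.variables_to_restore(): it returns a dictionary that keys each trainable variable by its shadow name, so the Saver loads the moving-average values stored in the checkpoint directly into the live variables. A tiny sketch of that mapping (the variable name is illustrative):

#coding:utf-8
import tensorflow as tf

v = tf.Variable(0.0, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
# Maps 'v/ExponentialMovingAverage' -> v, so restoring writes the shadow value into v
print(ema.variables_to_restore())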

Reposted from blog.csdn.net/sxlsxl119/article/details/81430995