Tensorflowr学习小结:Mnist手写数字识别

Mnist手写数字识别是利用神经网络来对手写数字0-9进行分类。简单来说这就是一个使用了监督学习的十分类问题。这里将使用前向传播,训练过程,检测正确率三个文件,并通过单隐蔽层的全连接神经网络来解决。笔者主要通过这个例子对常用的神经网络优化方法与模型持久化方法做个总结。

在进入问题之前先了解一下Mnist数据集,每个手写数字是28*28像素的图片。作为机器学习中的”hello world” ,tensorflow已经为我们提供了提取成了特征向量的Mnist数据集。一共分三部分:训练集(train),检测集(validate), 测试集(test)。每个数据集都有images与labels作为图片特征向量与图片上对应的数字矩阵。images中每个元素是1*784的特征向量。labels中的每个元素是1*10的标签,并且对应的图片数字作为下标的值为1,其余为0。这样我们就知道了神经网络输入层与输出层的结点个数。


了解了数据集的构成后我们开始定义前向传播过程,在定义前向传播时,需要注意三点。
1. 在生成参数矩阵时,需要根据有无传入l2正则化函数,来将模型复杂损失度加入名为“loss”的集合中,在训练文件中定义损失函数后需要一起加上,从而减少过度拟合的影响。
2. 对于每一层网络中的参数矩阵,通过variable_scope去进行管理。在变量很多的神经网络中使用variable_scope会省去传过多参数给函数的麻烦。
3. 需要使用激活函数去线性化,这里使用的是Relu曲线。

import tensorflow as tf
#定义输入与输出结点个数与隐蔽层的结点个数
INPUT_NODE=784
LAYERS=500
OUTPUT_NODE=10

def node_mat(fore_node,cur_node,regularizer):
    mat = tf.get_variable("weights",
                          shape=[fore_node, cur_node],
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer!=None:
        losses=regularizer(mat)
        tf.add_to_collection("loss",losses)
    return mat

def inference(input_tensor,regulairzer):
    with tf.variable_scope("layer_1"):
        weight_1=node_mat(INPUT_NODE,LAYERS,regulairzer)
        mat_tp=tf.nn.relu(tf.matmul(input_tensor,weight_1)+tf.random_normal(shape=[LAYERS]
                                                                            ,stddev=0.1))
    with tf.variable_scope("layer_2"):
        weight_2=node_mat(LAYERS,OUTPUT_NODE,regulairzer)
        result=tf.matmul(mat_tp,weight_2)+tf.random_normal(shape=[OUTPUT_NODE],
                                                           stddev=0.1)
#这里的输出不需要再进行去线性化,因为训练文件中计算损失函数会进行softmax处理
    return result

前向传播完成后,进入训练文件的处理,这里还要使用三个优化模型的手段

  1. 使用了交叉熵的损失函数
  2. 指数衰减的学习率
  3. 滑动平均模型

在训练过程中,每隔一定的步数就需要保存目前的模型训练情况,将模型持久化。使模型可以复用。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference as mif
#一系列常量的定义
INPUT=784
OUTPUT=10
BATCH_SIZE=100
TRAIN_STEPS=30000
REGULARIZER=0.0001  #正则化的系数
LEARNING_RATE=0.8  #初始学习率
DECAY_RATE=0.99  #学习率的衰减率
MOVINGAVERAGE=0.99  #滑动平均的衰减率
SAVE_PATH="./mnist/"  #模型文件保存的父目录

def train(mnist):
    x=tf.placeholder(dtype=tf.float32,shape=[None,INPUT],name="x-input")
    y_=tf.placeholder(dtype=tf.float32,shape=[None,OUTPUT],name="y-output")

    regularizer=tf.contrib.layers.l2_regularizer(REGULARIZER)
    #l2正则化函数
    result=mif.inference(x,regularizer)

    cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=result))
    #注意交叉熵算完后需要求平均得到一个batch的平均交叉熵。
    losses=cross_entropy+tf.add_n(tf.get_collection("loss"))
    #此时将l2正则化的损失与交叉熵求和,得到最终的损失函数

    global_step=tf.Variable(0,name="global_step",trainable=False)
    learning_rate=tf.train.exponential_decay(LEARNING_RATE,global_step,BATCH_SIZE,DECAY_RATE)
    #定义指数衰减的学习率时需要定义global_step
    #该变量是不参与训练的,用于记录训练步数

    moving_average=tf.train.ExponentialMovingAverage(MOVINGAVERAGE)
    mov_avg_op=moving_average.apply(tf.trainable_variables())
    #滑动平均模型的设置,在训练过程中维护两个系数矩阵的影子变量

    train=tf.train.GradientDescentOptimizer(learning_rate).minimize(losses,global_step)
    #使用梯度下降的优化器,注意在minimize中要给出global_step
    #否则学习率是不会衰减的

    with tf.control_dependencies([train,mov_avg_op]):
        train_op=tf.no_op(name="train")
    #设置训练与维护影子变量这两个操作的运行顺序

    saver=tf.train.Saver()
    #设置Saver用于保存模型

    #定义完优化过程后建立会话运算
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for steps in range(TRAIN_STEPS):
            xs,ys=mnist.train.next_batch(BATCH_SIZE)
            #返回两个列表,得到一个batch的输入与标签
            sess.run(train_op,feed_dict={x:xs,y_:ys})
            if(steps%1000==0):
                print(sess.run(losses,feed_dict={x:xs,y_:ys}))
                #每1000步输出损失函数大小
                saver.save(sess,os.path.join(SAVE_PATH,"mnist.ckpt"),global_step=global_step)
                #这里使用global_step可以使文件名最后加上当前训练的轮数

def main(argv=None):
    mnist=input_data.read_data_sets("./mnist_data",one_hot=True)
    #在相应路径下寻找有没有mnist数据集文件,如果没有的话会自行下载(巨方便)
    train(mnist)

if __name__=="__main__":
    tf.app.run()
    #使用tf提供的主函数入口

至此我们一共使用了5种优化模型的方式:

  1. 激活函数去线性化
  2. l2正则化去过拟合
  3. 损失函数
  4. 指数衰减的学习率
  5. 滑动平均模型的设置

同时每隔一定的步数就保存模型,在大型模型训练的过程时,这么做极其重要。
这里给出前五步的训练结果:
这里写图片描述


最后我们将已经训练完成的模型进行载入,并利用测试数据集去检测我们的模型正确率情况。每5秒传入一个测试数据集,并输出一个正确率。在这里我们可以直接调用已经写过了的前向传播函数。
需要注意的是,此时的系数矩阵使用的是值的影子变量的值,所以在载入模型时要进行变量名的对应。滑动平均模型提供了variables_to_restore()函数,提供了影子变量名到变量名对应的字典。

import tensorflow as tf
import mnist_inference
import mnist_train
from tensorflow.examples.tutorials.mnist import input_data
import time

def validate(mnist):
    x=tf.placeholder(dtype=tf.float32,shape=[None,mnist_train.INPUT],name="x-input")
    y_=tf.placeholder(dtype=tf.float32,shape=[None,mnist_train.OUTPUT],name="y-output")

    result=mnist_inference.inference(x,None)
    correct=tf.equal(tf.argmax(y_,1),tf.argmax(result,1))
    correct_rate=tf.reduce_mean(tf.cast(correct,dtype=tf.float32))
    #argmax是求相应维度上最大元素的下标,从而得到对应的识别数字
    #cast将得到的列表里的bool值转化成float型(0或1),求均值即得正确率

    moving_average=tf.train.ExponentialMovingAverage(mnist_train.MOVINGAVERAGE)
    shadows_variables=moving_average.variables_to_restore()
    #得到影子变量到变量自身的一个字典
    #用于载入模型时,可以直接使用影子变量来传入前向传播

    saver=tf.train.Saver(shadows_variables)
    #声明Saver用于载入模型,参数表示进行变量名的对应

    with tf.Session() as sess:
        validation=mnist.validation
        feed_dict={x:validation.images,y_:validation.labels}
        ckpt=tf.train.get_checkpoint_state(mnist_train.SAVE_PATH)
        #get_checkpoint_state()函数会寻找模型文件目录下最新的模型文件
        while True:
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess,ckpt.model_checkpoint_path)
                print(sess.run(correct_rate,feed_dict))
            else:
                print("NO Checkpoint")
                return
            time.sleep(5)
            #每隔五秒跑一次,注意要导入time这个文件
def main(argv=None):
    mnist=input_data.read_data_sets('./mnist_data',one_hot=True)
    validate(mnist)

if __name__=="__main__":
    tf.app.run()

这里给出前五个测试情况的正确率
这里写图片描述
正确率还是比较高的,如果使用卷积神经网络还可以更高,到99%多。

参考了TensorFlow 实战Google深度学习框架

猜你喜欢

转载自blog.csdn.net/qq_38069320/article/details/82322288
今日推荐