使用LSTM实现mnist手写数字分类识别 TensorFlow

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/kl1411/article/details/82877415

RNN做图像识别原理:MNIST数据集中一张图片数据包含28*28的像素点。RNN是将一张图片数据的一行作为一个向量总体输入一个X中。也就是说,RNN有28个输入X,一个输入X有28个像素点。
输出最后一个结果做为预测值。
 

TensorFlow入门学习代码:

# -*- coding: utf-8 -*-
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os.path as ops
import os

# LOG 等级说明
# '1' 这是默认的显示等级,显示所有信息
# '2' 只显示 warning 和 Error
# '3' 只显示 Error
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 设置 LOG 输出等级  

tf.reset_default_graph()  # 清除默认图形堆栈并重置全局默认图形

# 读取数据
mnist = input_data.read_data_sets('./mnist', one_hot=True)

# 定义超参数
training_iters = 20001
batch_size = 100  

n_inputs = n_steps = 28  # n_inputs是每步输入值,对应图像列数; n_steps是时间步数,对应图像行数
n_hidden_number = 128  #隐藏层神经元个数
n_outputs = 10  #输出层神经元个数,对应数字的10个类

x = tf.placeholder(tf.float32,[None,n_steps,n_inputs],name='x')  # 添加一个新的占位符用于输入正确值, 使用placeholder()来传递一个tensor到session.run()中
Y = tf.placeholder(tf.float32,[None,n_outputs],name='Y')

# 初始化参数
weights = {
        # shape = (28,128)  tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)  shape: 输出张量的形状,必选
        'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_number])),  # tf.random_normal从服从指定正太分布的数值中取出指定个数的值
        # shape = (128,10)
        'out':tf.Variable(tf.random_normal([n_hidden_number,n_outputs]))
        }

biases = {
        # shape = (128,)
        'in':tf.Variable(tf.constant(0.1,shape = [n_hidden_number,])),  # tf.constant生成一个给定值的常量
        # shape = (10,)
        'out':tf.Variable(tf.constant(0.1,shape = [n_outputs,]))
        }

# 模型
def RNN(X, weights, biases):
    ### 输入层到核运算 ###
    # X shape = (100batch,28steps,28inputs) ==> (100batch*28steps,28inputs)
    X = tf.reshape(X,[-1,n_inputs])  # 将tensor变换为参数shape的形式  -1表示自动计算
    # X_in shape ==> (100batch*28steps,128hidden)
    X_in = tf.matmul(X,weights['in'])+biases['in']  #tf.matmul 矩阵乘法
    # X_in shape ==> (100batch,28steps,128hidden)
    X_in = tf.reshape(X_in,[-1,n_steps,n_hidden_number])
    
    ### cell核内运算 ###
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_number,forget_bias = 1.0)  # 选择的cell是BasicLSTMCell
    # LSTM cell is divided into two parts-->(c_state,m_state)
    # RNN的中间状态会得到两部分——一个是当前输出outputs,另一个是下时刻的记忆states,RNN在init_state中用c_state、m_state分别保存这两部分
    init_state = lstm_cell.zero_state(batch_size,dtype=tf.float32)  
    # 用dynamic_rnn的方法,用输入值X_in进行核内运算,将输出分别存入相应数组中
    outputs,states = tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=init_state,time_major=False)
    
    ### 核内运算到输出层 ###
    # 在核内以X_in为输入,得到输出outputs与states,当所有行都代入计算得到最后的输出预测值。其中states[1] = outputs[-1],相当于最后一个输出值。
    result = tf.matmul(states[1],weights['out'])+biases['out']
    return  result

# 保存模型
saver = tf.train.Saver(max_to_keep=2)
# 保存模型的路径
ckpt_file_path = "./models/"  # models是文件夹,mnist是文件命名使用的
path = os.path.dirname(os.path.abspath(ckpt_file_path))
if os.path.isdir(path) is False:
    os.makedirs(path)

# 代价函数和优化器
prediction = RNN(x,weights,biases)
tf.add_to_collection('predict', prediction)
#二次代价函数:预测值与真实值的误差
# softmax_cross_entropy_with_logits分为两步:第一步是先对网络最后一层的输出做一个softmax,第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=prediction))  # labels:实际的标签,logits:就是神经网络最后一层的输出
#梯度下降法:数据庞大,选用AdamOptimizer优化器
train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)

#结果存放在一个布尔型列表中
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(Y,1))
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),name="accuracy")

# 训练
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    step = 0
    
    while step*batch_size < training_iters:
        batch_xs,batch_ys = mnist.train.next_batch(batch_size)
        # batch_xs shape = [100,28,28]
        batch_xs = batch_xs.reshape([batch_size,n_steps,n_inputs])
#        batch_ys = batch_ys.reshape([batch_size,n_outputs])
    
        train_step.run(feed_dict = {x:batch_xs,Y:batch_ys,})  # 使用feed_dict来传入tensor
#        sess.run([train_step], feed_dict = {x:batch_xs, Y:batch_ys,})
        
        if step % 50 == 0:
            train_accuracy = accuracy.eval(feed_dict={x:batch_xs,Y:batch_ys,})  # 评估模型,得出训练的准确率
            print("step", step, "training accuracy", train_accuracy)
            
        if step % 100 == 0:
            model_name = 'mnist_{:s}'.format(str(step+1))
            model_save_path = ops.join(ckpt_file_path, model_name)
            saver.save(sess, model_save_path, write_meta_graph=True)  # 保存模型 
            
        step += 1       

    # 用测试数据再评估下
    test_len = 100  # 必须跟batch_size相等
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_inputs))
    test_label = mnist.test.labels[:test_len]
    print ("Testing Accuracy:", \
           sess.run(accuracy, feed_dict={x: test_data, Y: test_label}))

加载模型代码:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

mnist = input_data.read_data_sets("./mnist", one_hot=True)

n_input = 28
n_steps = 28
n_classes = 10

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./models/mnist_201.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./models'))
    print("model loaded")

    graph = tf.get_default_graph()
    predict = tf.get_collection('predict')[0]
    input_x = graph.get_operation_by_name("x").outputs[0]
    y = graph.get_tensor_by_name("Y:0")
    Accuracy = graph.get_tensor_by_name("accuracy:0")
    print("arg init")

    x = mnist.test.images[0].reshape((-1, n_steps, n_input))
    y = mnist.test.labels[0].reshape(-1, n_classes)  # 转为one-hot形式
    print("reshape done")   
    
    res = sess.run(predict, feed_dict={input_x: x })
    print("demo run")

    print("Actual class: ", str(sess.run(tf.argmax(y, 1))), \
          ", predict class ",str(sess.run(tf.argmax(res, 1))), \
          ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(res, 1))))
          )

猜你喜欢

转载自blog.csdn.net/kl1411/article/details/82877415