(八)Tensorflow的LSTM对(MNIST)图像数据的处理

长短时记忆模型LSTM是对经典的RNN记忆力弱的缺点进行改进的一种模块,一般用于时序序列(文本,语音)数据,

但是图片数据处理效果也比较好,就是给图片数据加入记忆元素效果也不错!

这个模块比较简单,主要复杂在于,神经网络维度的unmatch,以及API的熟悉程度,LSTM模块:主要是隐层的状态和

数据reshape以及返回值的使用,参数对应的是隐藏层来说的!训练模块:就是构建loss的可视化,以及模型的保存!

1.定值和库的展示

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn
import os,math

MODEL_PATH = "./lstm_model"
MODEL_NAME = "model.ckpt"
TRAIN_NUM_ITERATION = 10000
LEARNING_RATE = 0.001
init_iteration = 4000  #每次修改最新的iteration,否则会报错!!!

NUM_STEP = 1000 #每训练一定次数进行测试和计算损失
INPUT_SIZE = 784
OUTPUT_SIZE = 10
HIDDEN_SIZE = 128 #时间序列之间的隐藏层,“门”
TIME_STEP = 28  #时间序列的长度,“把28个图像维度当成一整个时间序列”
DATA_SIZE = 28
BATCH_SIZE = 128

2.模型展示

class LSTM_MODEL:
    def __init__(self,input_size,output_size,hidden_size,time_step,batch_size,learning_rate,data_size):
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.time_step = time_step
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.data_size = data_size

    def _create_data(self):
        "利用placeholder来给数据创建占位符"
        self.x = tf.placeholder(tf.float32,[self.batch_size,self.input_size],name="x")
        self.y = tf.placeholder(tf.float32,[self.batch_size,self.output_size],name="y")
        return

    def _create_lstm(self):
        "创建网络结构,并进行输出最终结果"
        """
        # input_data = tf.reshape(self.x,[self.batch_size,self.time_step,self.data_size])
        # lstm_cell_state = rnn.BasicLSTMCell(self.hidden_size)
        # outputs, final_state = tf.nn.dynamic_rnn(lstm_cell_state, input_data, dtype=tf.float32)
        #取outputs的话需要reshape下[batch_size, max_time, cell.output_size]`.
        #final_state是[batch_size, cell.state_size]`
        # self.outputs = final_state[1]  # 1是取到最后一个输出结果,0是隐藏层输出的结果是记忆信息
        # self.outputs = outputs[:, -1]  # 或者这样,一定注意
        """
        input_data = tf.reshape(self.x, [self.batch_size, self.time_step, self.data_size]) #先reshape下
        input_data = tf.unstack(input_data,self.time_step,axis=1) #在按时间序列转化为序列格式
        lstm_cell_state = rnn.BasicLSTMCell(self.hidden_size)
        #outputs is a length T list of outputs ,取最后一个即可:-1
        outputs, final_state = rnn.static_rnn(lstm_cell_state, input_data, dtype=tf.float32)
        self.outputs = final_state[1] # 1是取到最后一个输出结果,0是隐藏层输出的结果是记忆信息  或者  self.outputs = outputs[-1]
        weights = tf.Variable(tf.truncated_normal([self.hidden_size,self.output_size],stddev=0.1),name="weight")
        biases = tf.Variable(tf.zeros([self.output_size]),name="biases",dtype=tf.float32)
        self.outputs = tf.add(biases,tf.matmul(self.outputs,weights))
        return

    def _create_loss(self):
        "定义损失函数,并定义优化器"
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.outputs,labels=self.y))
        self.train = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        return

    def _create_pred(self):
        "定义测试样本准确率"
        pred = tf.equal(tf.argmax(self.y,axis=1),tf.argmax(self.outputs,axis=1))
        self.accuary = tf.reduce_mean(tf.cast(pred,tf.float32))  #这里


    def _create_summary(self):
        "定义summary,用于可视化"
        tf.summary.scalar("loss",self.loss)
        tf.summary.histogram("histogram",self.loss)
        self.summary_op = tf.summary.merge_all()
        return

    def build_graph(self):
        "创建定义好的图模型"
        self._create_data()
        self._create_lstm()
        self._create_loss()
        self._create_pred()
        self._create_summary()
        return

3.训练模块展示

def train(model,train_num_iteration,model_path,model_name,batch_size):
    #将模型的训练和测试放在一起了
    mnist = input_data.read_data_sets('/data/mnist', one_hot=True)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        count = 0
        if os.path.exists(model_path):
            path = os.path.join(model_path,model_name)+"-"+str(init_iteration)
            saver.restore(sess,path)
        else:
            os.mkdir(model_path)
            sess.run(init)
        writer = tf.summary.FileWriter("./lstm_graph",sess.graph)
        loss_total = 0
        for i in range(init_iteration,train_num_iteration):
            #训练
            count += 1
            x_input,y_input = mnist.train.next_batch(batch_size)
            feed_dict = {model.x:x_input,model.y:y_input}
            _,l,summary = sess.run([model.train,model.loss,model.summary_op],feed_dict=feed_dict)
            loss_total+=l
            writer.add_summary(summary,global_step=i+1)
            #测试
            if count ==NUM_STEP:
                n_num = math.ceil( mnist.test.num_examples/batch_size )
                accuary_total = 0
                for _ in range( n_num):
                    x_input,y_input = mnist.test.next_batch(batch_size)
                    feed_dict = {model.x: x_input, model.y: y_input}
                    accuary_total+=sess.run(model.accuary,feed_dict=feed_dict)
                print("Iteratin:{},train_loss:{},accuary:{}".format(i+1,loss_total/NUM_STEP,accuary_total/n_num))
                count=0
                loss_total = 0
                #保存模型
                saver.save(sess,os.path.join(model_path,model_name),global_step=i+1)
        writer.close()
    return

4.启动代码

if __name__ == '__main__':
    model = LSTM_MODEL(INPUT_SIZE,OUTPUT_SIZE,HIDDEN_SIZE,TIME_STEP,BATCH_SIZE,LEARNING_RATE,DATA_SIZE)
    model.build_graph()
    train(model,TRAIN_NUM_ITERATION,MODEL_PATH,MODEL_NAME,BATCH_SIZE)

猜你喜欢

转载自blog.csdn.net/taka_is_beauty/article/details/89078226
今日推荐