tensorflow使用多层RNN(lstm)预测手写数字实现部分细节及踩坑总结

输入格式:batch_size*784改成batch_size*28*28,28个序列,内容是一行的28个灰度数值。

让神经网络逐行扫描一个手写字体图案,总结各行特征,通过时间序列串联起来,最终得出结论。

网络定义:单独定义一个获取单元的函数,便于在MultiRNNCell中调用,创建多层LSTM网络

def get_a_cell(i):
    lstm_cell =rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias = 1.0, state_is_tuple = True, name = 'layer_%s'%i)
    print(type(lstm_cell))
    dropout_wrapped = rnn.DropoutWrapper(cell = lstm_cell, input_keep_prob = 1.0, output_keep_prob = keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells = [get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)#tf.nn.rnn_cell.MultiRNNCell

简单说一下其他细节和坑:RNN有不同的运行方法,最简单的是用dynamic,直接吃结果。

outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
final_out = outputs[:,-1,:]

也可以写个循环手动运行seq_num次,得到最终结果(下面两种形式,反正本质都是调用__call__):

outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(STEP_SIZE):
        # (cell_output, state) = multi_lstm(tf_x_reshaped[:,timestep,:],state)
        (cell_output, state) = multi_lstm.call(tf_x_reshaped[:,timestep,:],state)
        outputs.append(cell_output)
        # print('cell_output:', cell_output)
h_state = outputs[-1]

batch_size对RNN是有影响的,因为LSTM有0号状态需要初始化的,这个是和batch_size挂钩的。所以最好把batch_size用placeholder输入,而不是常量。

这里测试集传入数据的batch_size想比训练传入数据的batch_size大一些,就会报错!当然,如果我懒得改代码,也可以用个小循环多次取小数据集的结果,最后取平均。

init_state = multi_lstm.zero_state(batch_size = BATCH_SIZE, dtype = tf.float32)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.summary.FileWriter('graph', graph=sess.graph)
    for i in range(2000):
        x,y = MNIST.train.next_batch(BATCH_SIZE)
        _, loss_,outputs_, state_, right_predictions_num_  = \
            sess.run([train_op, cross_entropy,outputs, state,right_predictions_num], {tf_x:x, tf_y:y, keep_prob:1.0})
        print('loss:', loss_)
        # print('right_predictions_num_:', right_predictions_num_)

        if i % 200 == 0:
            # tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp: Dimensions of inputs should match: shape[0] = [1000, 28] vs.shape[1] = [100, 256]
            # test_x, test_y = MNIST.test.next_batch(BATCH_SIZE * 10)
            total_accuracy = 0.
            total_test_batch = 10
            for j in range(total_test_batch):
                test_x, test_y = MNIST.test.next_batch(BATCH_SIZE)
                accuracy_ =  sess.run([accuracy], {tf_x:test_x, tf_y:test_y, keep_prob:1.0})
                total_accuracy += accuracy_[0]
            total_accuracy = total_accuracy / total_test_batch
            print('total_accuracy:', total_accuracy)

本例实现代码:

https://github.com/huqinwei/tensorflow_demo/blob/master/lstm_mnist/multi_lstm.py

lstm多层结构state的存在形式:

https://blog.csdn.net/huqinweI987/article/details/83148239

猜你喜欢

转载自blog.csdn.net/huqinweI987/article/details/83155110