使用多层 LSTM API(4/7)

这一次我们会让架构层次更深,使用LSTM多层结构

需要注意的是,在网络的每一层,我们都需要一个hidden state和一个cell state,
特别的是,输入到下一个LSTM层的输入,是那一个特定层的前一个状态,
隐藏的前一层的激活层也是
[这尼玛说的是啥?]

我们要 把每一层的states保存起来,将会有很多个LSTMTuples
为了方便,我们会用一个整的状态来代替之前的_current_cell_state_current_hidden_state

_current_state = np.zeros((num_layers, 2, batch_size, state_size))

这里num_layers=3

2代表了2个states,cell和hidden

现在修改之前的代码

_total_loss, _train_step, _current_state, _predictions_series = sess.run(
    [total_loss, train_step, current_state, predictions_series],
    feed_dict={
        batchX_placeholder: batchX,
        batchY_placeholder: batchY,
        # 这里
        init_state: _current_state
    })

然后替换这里,换成一个 tensor

#cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
#hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
#init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)
init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

既然TensorFlow的多层api接受 作为LSTMTuples的state, 我们需要对state的数据结构动些手脚

对于状态里的每一层,我们创建一个LSTMTuple,然后把它们放到一个元组(tuple)里,在init_state后面加上:

state_per_layer_list = tf.unstack(init_state, axis=0)
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
     for idx in range(num_layers)]
)

然后forward pass的部分修改成

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state=rnn_tuple_state)

[cell]* num_layers是对cell进行复制num次

多层LSTM一开始被用来创建一个单个的LSMTCell
然后在一个数组里复制这个cell
把它提供给MultiRNNCell的api调用

TensorFlow1.2的api调整

不能写成[cell]* num_layers来复制成多层cell了,
会提示ValueError: Trying to share variable rnn ...
每个cell都要单独生成,在放在同一个list里

def lstm_cell():
    return tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
stacked_lstm_cell = [lstm_cell() for _ in range(num_layers)]

cell = tf.nn.rnn_cell.MultiRNNCell(stacked_lstm_cell, state_is_tuple=True)

猜你喜欢

转载自blog.csdn.net/sinat_24070543/article/details/75255852