LSTM 有两个状态:cell state 和 hidden state。(这俩是同一个吗?不是:cell state 是 LSTM 内部的记忆单元 c,hidden state 是每个时间步输出的隐藏状态 h,两者需要分别保存和传递。)
下面把原来 RNN 代码中的状态处理替换为 LSTM 版本,共有四处替换。第一处替换在训练循环中,初始化状态时:
# First replacement (training loop): the plain-RNN version kept one state
# array; the LSTM needs two zero-initialised arrays, one for the cell state
# and one for the hidden state, reset at the start of every epoch.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Interactive plotting mode (presumably so a loss plot can be refreshed
    # during training).
    plt.ion()
    plt.figure()
    plt.show()
    # Presumably collects losses for plotting; filled elsewhere in the script.
    loss_list = []
    for epoch_idx in range(num_epochs):
        x,y = generateData()
        # First replacement: the single RNN state below is split into c and h.
        # _current_state = np.zeros((batch_size, state_size))
        _current_cell_state = np.zeros((batch_size, state_size))
        _current_hidden_state = np.zeros((batch_size, state_size))
        print("New data, epoch", epoch_idx)
TensorFlow 的 LSTM 内部使用一个名为 `LSTMStateTuple` 的数据结构来保存状态:第一个元素是 cell state(c),第二个元素是 hidden state(h)。
第二处替换:创建 init_state placeholder 时。
# Second replacement (graph construction): the single init_state placeholder
# becomes two placeholders, one for the cell state and one for the hidden
# state, packed into the LSTMStateTuple that the tuple-state LSTM cell uses.
batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])
# Second replacement: the previous single-state placeholder was
# init_state = tf.placeholder(tf.float32, [batch_size, state_size])
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
# LSTMStateTuple field order is (c, h): cell state first, hidden state second.
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)
第三处替换:创建 LSTM cell 以及构建 states_series 时。
# Third replacement: swap BasicRNNCell for BasicLSTMCell. With
# state_is_tuple=True the cell state is an LSTMStateTuple, so the init_state
# tuple built from the two placeholders is passed as initial_state, and the
# returned current_state has the same (c, h) tuple structure.
# Previous plain-RNN version, kept for reference:
# cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
# states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state = init_state)
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state = init_state)
最后一处替换:运行计算图时,通过 feed_dict 给 placeholder 喂入数据。
# Run one training step on the current truncated-backprop window, feeding the
# LSTM's previous state through two separate placeholders (cell state c and
# hidden state h) instead of the single init_state placeholder of the plain RNN.
_total_loss, _train_step, _current_state, _predictions_series = sess.run(
    [total_loss, train_step, current_state, predictions_series],
    feed_dict={
        batchX_placeholder: batchX,
        batchY_placeholder: batchY,
        # Plain-RNN version fed the single state: init_state: _current_state,
        # Last replacement: feed c and h separately.
        cell_state: _current_cell_state,
        hidden_state: _current_hidden_state,
    })
# current_state is an LSTMStateTuple, so sess.run returns it as a (c, h)
# tuple of arrays. Unpack it so the next window continues from this window's
# final state. Without this, every window would restart from the zero
# state initialised at the top of the epoch.
_current_cell_state, _current_hidden_state = _current_state