Using TensorFlow's LSTM API (3/7)

An LSTM keeps both a cell state and a hidden state. Unlike the plain RNN used so far, these are two separate tensors, not one and the same, so everywhere the code previously handled a single state it now has to handle a pair.
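
A quick way to see the difference (a rough sketch, assuming TensorFlow 1.x and the state_size value used in the earlier parts of this series):


import tensorflow as tf

state_size = 4  # assumed value, as in the earlier parts of this series

# A plain RNN cell carries a single hidden state...
print(tf.nn.rnn_cell.BasicRNNCell(state_size).state_size)
# -> 4

# ...while an LSTM cell carries a (cell state, hidden state) pair
print(tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True).state_size)
# -> LSTMStateTuple(c=4, h=4)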

First replacement: in the training loop, initialize separate zero arrays for the cell state and the hidden state instead of the single RNN state.


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()

        # First replacement: the single RNN state becomes two zero arrays,
        # one for the cell state and one for the hidden state
        # _current_state = np.zeros((batch_size, state_size))
        _current_cell_state = np.zeros((batch_size, state_size))
        _current_hidden_state = np.zeros((batch_size, state_size))

        print("New data, epoch", epoch_idx)

Internally, TensorFlow's LSTM passes its state around as a data structure called LSTMStateTuple,
whose first element is the cell state and whose second element is the hidden state.
The second replacement is made where the init_state placeholder is created:


batchX_placeholder = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])
batchY_placeholder = tf.placeholder(tf.int32, [batch_size, truncated_backprop_length])

# Second replacement: two state placeholders wrapped in an LSTMStateTuple
# init_state = tf.placeholder(tf.float32, [batch_size, state_size])
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)
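
LSTMStateTuple is just a named pair, so the two fields of init_state are exactly the placeholders created above. A minimal sanity check at graph-construction time (illustrative only, not part of the original code):


print(init_state.c is cell_state)    # True: first field is the cell state
print(init_state.h is hidden_state)  # True: second field is the hidden state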

Third replacement: where the cell is created and states_series is built, swap the basic RNN cell for an LSTM cell.

# Third replacement: BasicRNNCell becomes BasicLSTMCell
# cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
# states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state=init_state)
cell = tf.nn.rnn_cell.BasicLSTMCell(state_size, state_is_tuple=True)
states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state=init_state)
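
For orientation, a rough sketch of what the new call returns (assuming inputs_series is the per-time-step list unstacked from batchX_placeholder, as in the earlier parts of this series): states_series is a list with one hidden-state output per unrolled step, while current_state is now an LSTMStateTuple holding the final cell and hidden states.


print(len(states_series))            # truncated_backprop_length
print(current_state.c.get_shape())   # (batch_size, state_size)
print(current_state.h.get_shape())   # (batch_size, state_size)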

Finally, feed the new placeholders inside the training loop:

_total_loss, _train_step, _current_state, _predictions_series = sess.run(
    [total_loss, train_step, current_state, predictions_series],
    feed_dict={
        batchX_placeholder: batchX,
        batchY_placeholder: batchY,
        # Last replacement: feed the two state arrays instead of init_state
        # init_state: _current_state,
        cell_state: _current_cell_state,
        hidden_state: _current_hidden_state
    })
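
One step that is easy to lose in this excerpt: the _current_state returned by sess.run is now an LSTMStateTuple of numpy arrays, so it has to be unpacked back into the two arrays that the next iteration feeds. A minimal sketch, assuming the loop structure from the first replacement:


# carry both halves of the state over to the next truncated-backprop batch
_current_cell_state, _current_hidden_state = _current_state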

Reposted from blog.csdn.net/sinat_24070543/article/details/75246482