tf.nn.dynamic_rnn 详解
参考: https://zhuanlan.zhihu.com/p/43041436
output, last_state = tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
name | shape |
---|---|
cell | int, lstm or gru的神经元数,与输出size有关 |
input | [batch_size, max_length, embedding_size] |
sequence_length | [int, int,…]对应输入序列的实际长度,应用于padding的非定长输入 |
output | [batch_size, max_length, cell] |
state | [batch_size, cell.output_size ] or [2, batch_size, cell.output_size ] |
output 和state的关系
以上两个图是lstm的结构,对应的last_state有【
】,cell_state(应该记住或遗忘的状态),
(实际的输出),因此state是【2, batch_size, cell】
对应中间的每一个状态【batch_size, max_length, cell_size】
last_state中的
对应的是output中最后一个输出(每一个输入最后一个不为0的部分)
例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【2,3,5】
GRU是LSTM修改的RNN,对应只有一个输出,以及向后层传递的
,所以state=【batch_size, cell_size】
同理,对于gru,例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【3,5】