tensorflow关于lstm/gru实现细节

tf.nn.dynamic_rnn 详解

参考: https://zhuanlan.zhihu.com/p/43041436

output, last_state = tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
name shape
cell int, lstm or gru的神经元数,与输出size有关
input [batch_size, max_length, embedding_size]
sequence_length [int, int,…]对应输入序列的实际长度,应用于padding的非定长输入
output [batch_size, max_length, cell]
state [batch_size, cell.output_size ] or [2, batch_size, cell.output_size ]
output 和state的关系

在这里插入图片描述
在这里插入图片描述
以上两个图是lstm的结构,对应的last_state有【 c t , h t c_t, h_t 】,cell_state(应该记住或遗忘的状态), h t h_t (实际的输出),因此state是【2, batch_size, cell】
c t c_t 对应中间的每一个状态【batch_size, max_length, cell_size】
last_state中的 h t h_t 对应的是output中最后一个输出(每一个输入最后一个不为0的部分)

例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【2,3,5】

在这里插入图片描述
GRU是LSTM修改的RNN,对应只有一个输出,以及向后层传递的 h t h_t ,所以state=【batch_size, cell_size】

同理,对于gru,例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【3,5】

tf.nn.bidirectional_dynamic_rnn

发布了98 篇原创文章 · 获赞 9 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/qq_40168949/article/details/100535013