LSTM在keras中的参数return_sequences和return_state

原文链接:https://blog.csdn.net/u011327333/article/details/78501054

Understand the Difference Between Return Sequences and Return States for LSTMs in Keras

Kears LSTM API 中给出的两个参数描述

  • return_sequences:默认 False。在输出序列中,返回单个 hidden state值还是返回全部time step 的 hidden state值。 False 返回单个, true 返回全部。
  • return_state:默认 False。是否返回除输出之外的最后一个状态。

区别 cell state 和 hidden state

LSTM 的网络结构中,直接根据当前 input 数据,得到的输出称为 hidden state。
还有一种数据是不仅仅依赖于当前输入数据,而是一种伴随整个网络过程中用来记忆,遗忘,选择并最终影响 hidden state 结果的东西,称为 cell state。 cell state 就是实现 long short memory 的关键。

这里写图片描述

如图所示, C 表示的就是 cell state。h 就是hidden state。(选的图不太好,h的颜色比较浅)。整个绿色的矩形方框就是一个 cell。

cell state 是不输出的,它仅对输出 hidden state 产生影响。

通常情况,我们不需要访问 cell state,除非想设计复杂的网络结构时。例如在设计 encoder-decoder 模型时,我们可能需要对 cell state 的初始值进行设定。

keras 中设置两种参数的讨论

1.return_sequences=False && return_state=False

h = LSTM(X)
  
  
  • 1

Keras API 中,return_sequences和return_state默认就是false。此时只会返回一个hidden state 值。如果input 数据包含多个时间步,则这个hidden state 是最后一个时间步的结果

2.return_sequences=True && return_state=False

LSTM(1, return_sequences=True)
  
  
  • 1

输出的hidden state 包含全部时间步的结果。

3.return_sequences=False && return_state=True

lstm1, state_h, state_c = LSTM(1, return_state=True)
  
  
  • 1

lstm1 和 state_h 结果都是 hidden state。在这种参数设定下,它们俩的值相同。都是最后一个时间步的 hidden state。 state_c 是最后一个时间步 cell state结果。

为什么要保留两个值一样的参数? 马上看配置4就会明白

为了便于说明问题,我们给配置3和配置4一个模拟的结果,程序结果参考reference文献。

[array([[ 0.10951342]], dtype=float32), # lstm1
 array([[ 0.10951342]], dtype=float32), # state_h
 array([[ 0.24143776]], dtype=float32)] # state_c
  
  
  • 1
  • 2
  • 3

3.return_sequences=True && return_state=True

lstm1, state_h, state_c = LSTM(1, return_sequences=True, return_state=True)
  
  
  • 1

此时,我们既要输出全部时间步的 hidden state ,又要输出 cell state。

lstm1 存放的就是全部时间步的 hidden state。

state_h 存放的是最后一个时间步的 hidden state

state_c 存放的是最后一个时间步的 cell state

一个输出例子,假设我们输入的时间步 time step=3

[array([[[-0.02145359],
        [-0.0540871 ],
        [-0.09228823]]], dtype=float32),
 array([[-0.09228823]], dtype=float32),
 array([[-0.19803026]], dtype=float32)]
  
  
  • 1
  • 2
  • 3
  • 4
  • 5

可以看到state_h 的值和lstm1的最后一个时间步的值相同。

state_c 则表示最后一个时间步的 cell state

发布了5 篇原创文章 · 获赞 5 · 访问量 403

猜你喜欢

转载自blog.csdn.net/rocking_struggling/article/details/104318023