RNN+LSTM笔记

Naive RNN

概述

循环神经网络(Recurrent Neural Network,RNN)是一种用于处理序列数据的神经网络。相比一般的神经网络来说,他能够处理序列变化的数据。比如某个单词的意思会因为上文提到的内容不同而有不同的含义,RNN就能够很好地解决这类问题,因此在机器翻译、问答系统等NLP领域有很重要的应用。

结构

在这里插入图片描述

其中,x表示当前状态下数据的输入,h表示接收到的上一个节点的输入,y表示当前状态下的输出,h’表示传递到下一个节点的输出。图中的公式表明了h‘是历史状态(长期记忆)h和当前状态的输入x的线性组合,而y则是对h’进行线性操作。

假设输入为一个序列,则RNN则可以展开成这样的结构:

在这里插入图片描述

问题

但是实际使用中,RNN具有很严重的梯度消失和梯度爆炸的问题,这是个实践问题,而不是一个理论问题,因为合适的参数一定存在,但是RNN的这组合适的参数不容易找到,这对于RNN来说犹如外卖软件不能点外卖一样致命。

LSTM

概述

LSTM的全称是Long short-term memory,即长短期记忆网络,它有效缓解了RNN中梯度消失的问题(梯度爆炸的问题可以从其他技术中得到缓解),因此在更长的序列上又更优秀的表现。

补充:更加严谨的叫法其实是 LSTM-RNN,即带有长短期记忆网络单元的RNN网络。

外部结构

从输入输出上,LSTM和传统RNN的对比如下所示

在这里插入图片描述

传统RNN只有一个传递状态h,但是LSTM有两个传递状态 c c c h h h,前者代表cell state,是在前一个状态输出的c的基础上调整来的,变化缓慢,所以代表长期记忆(long-term memory);而后者表示hidden state,因为作为和输入x拼接的变量,因此与相邻LSTM单元的输入值有关,变化较大,所以代表短期记忆(short-term memory)

可见,LSTM中的c反而和naive RNN的h更加相似,代表着long-term memory。

内部结构

首先,第t个LSTM单元的输入为当前输入 x t x^t xt和上一个状态传递下来的short-term memory h t − 1 h^{t-1} ht1拼接并训练得到的四个状态:

在这里插入图片描述
在这里插入图片描述

其中,zf,zi,zo是由拼接向量乘以权重矩阵之后,再通过一个 sigmoid激活函数转换成0到1之间的数值,来作为一种门控状态,分别表示遗忘门控、输入门控、输出门控。

而 z 则是将结果通过一个tanh激活函数将转换成-1到1之间的值(这里使用 tanh 是因为这里是将其做为输入数据,而不是门控信号,也有人认为,使用tanh也是传递状态c和h不同的本质原因)

下面则是LSTM的内部结构

在这里插入图片描述

⊙ \odot 代表Hadamard Product,即矩阵对应元素相乘,这要求两个矩阵是同型的, ⊕ \oplus 代表矩阵相加。

LSTM主要有三个阶段,分别对应三个门限:

  1. 忘记阶段。这个阶段主要是对上一个节点传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。具体来讲,就是 z f z^f zf控制 c t − 1 c^{t-1} ct1的哪些部分需要遗忘。

  2. 选择记忆阶段。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 x t x^t xt进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 z z z表示,而选择的门控信号则是由 z i z^i zi(有人认为i代表information,也有人认为i代表input)来进行控制。
    将上面两步得到的结果相加,即可得到传输给下一个状态的 c t c^t ct ,也就是上图中的第一个公式。

  3. 输出阶段。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 z o z^o zo来进行控制的,并且还对上一阶段得到的 c o c^o co进行了放缩(通过一个tanh激活函数进行变化)。

y t y^t yt的输出和 h t h^t ht有关,这一点和naive RNN很相似。

为什么LSTM缓解了RNN的梯度消失问题

为什么RNN很容易出现梯度消失

假设使用SGD train RNN模型的参数,则有:
w i + 1 = w i − r ∂ L o s s ∂ w ∣ w : w i w^{i+1}=w^i-r\frac{\partial Loss}{\partial w}|_{w:w^i} wi+1=wirwLossw:wi
而naive RNN模型的输出参数的表达式为
h ′ = σ ( w h ⋅ h + w i ⋅ x ) y = σ ( w o ⋅ h ′ ) h'=\sigma(w^h\cdot h+w^i\cdot x) \\ y=\sigma(w^o\cdot h') h=σ(whh+wix)y=σ(woh)
其中, w h , w i , w o w^h,w^i,w^o wh,wi,wo都是要学习的参数。

接下来我们计算损失函数,这里的损失函数要考虑从0时刻到t时刻的损失函数求和,又称为BPTT(Back Propagation Trough Time):
L = ∑ t = 0 T L t L=\sum_{t=0}^T{L_t} L=t=0TLt
计算损失函数的梯度:
∂ L ∂ W = ∑ t = 0 T ∂ L t ∂ W \frac{\partial L}{\partial W}=\sum_{t=0}^{T} \frac{\partial L_{t}}{\partial W} WL=t=0TWLt
分别列出损失函数对三个梯度的求导表达式:
∂ L ∂ W = ∑ t = 0 T ∂ L t ∂ W ∂ L t ∂ W o = ∑ t = 0 T ∂ L t ∂ y t ∂ y t ∂ W o ∂ L t ∂ W i = ∑ t = 0 T ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W i ∂ L t ∂ W h = ∑ t = 0 T ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ j = k + 1 t ∂ h j ∂ h j − 1 ) ∂ h k ∂ W h \frac{\partial L}{\partial W}=\sum_{t=0}^{T} \frac{\partial L_{t}}{\partial W}\frac{\partial L_{t}}{\partial W^{o}}=\sum_{t=0}^{T} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial W^{o}}\\ \frac{\partial L_{t}}{\partial W^{i}}=\sum_{t=0}^{T} \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial h_{t}}\left(\prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}\right) \frac{\partial h_{k}}{\partial W^{i}}\\ \frac{\partial L_{t}}{\partial W^{h}}=\sum_{t=0}^{T} \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial y_{t}} \frac{\partial y_{t}}{\partial h_{t}}\left(\prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}\right) \frac{\partial h_{k}}{\partial W^{h}} WL=t=0TWLtWoLt=t=0TytLtWoytWiLt=t=0Tk=0tytLthtytj=k+1thj1hjWihkWhLt=t=0Tk=0tytLthtytj=k+1thj1hjWhhk
可见,导致RNN容易梯度消失和爆炸的就是这里的连乘项,其本质是 h t h_t ht W i W^i Wi求导需要链式求导到最初的 h k h^k hk,为了进一步分析,我们将连乘项中的 h j h_j hj h j − 1 h_{j-1} hj1的导数展开:
∂ h j ∂ h j − 1 = σ ′ W h = σ ( 1 − σ ) W h \frac{\partial h_{j}}{\partial h_{j-1}}=\sigma^{\prime} W^{h}=\sigma(1-\sigma)W^h hj1hj=σWh=σ(1σ)Wh
其中, σ ( 1 − σ ) \sigma(1-\sigma) σ(1σ)的取值范围是(0,0.25),即如果 W h W^h Wh小于4,连乘式的每一项都是小于1的,就很容易发生梯度消失!反之,梯度爆炸则不那么容易发生。

LSTM如何避免了梯度消失

LSTM的BPTT展开式非常复杂,但是核心是将RNN中的连乘式替换为了 c j c_j cj c j − 1 c_{j-1} cj1的求导:
∏ j = k + 1 t ∂ h j ∂ h j − 1 → ∏ j = k + 1 t ∂ c j ∂ c j − 1 \prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}\to\prod_{j=k+1}^{t} \frac{\partial c_{j}}{\partial c_{j-1}} j=k+1thj1hjj=k+1tcj1cj
这也从侧面佐证了我们上面所说的,LSTM中的传递状态c和RNN中的传递状态h都代表着long-term memory。

忽略bias, c j c_j cj c j − 1 c_{j-1} cj1的关系式为
c j = ( z f ⊙ c j − 1 ) ⊕ ( z i ⊙ z ) = [ σ ( W f x j + b f ) ⊙ c j − 1 ] ⊕ [ σ ( W i x j + b i ) ⊙ tanh ⁡ ( W x j + b ) ] \begin{aligned} c_{j} &=\left(z_{f} \odot c_{j-1}\right) \oplus\left(z_{i} \odot z\right) \\ &=\left[\sigma\left(W^{f} x_{j}+b^{f}\right) \odot c_{j-1}\right] \oplus\left[\sigma\left(W^{i} x_{j}+b^{i}\right) \odot \tanh \left(W x_{j}+b\right)\right] \end{aligned} cj=(zfcj1)(ziz)=[σ(Wfxj+bf)cj1][σ(Wixj+bi)tanh(Wxj+b)]
c j c_j cj c j − 1 c_{j-1} cj1的连乘项等于 σ ( W f x j + b f ) \sigma(W^fx_j+b^f) σ(Wfxj+bf),其取值范围是(0,1),这就不是很容易发生梯度消失了。【注意:此时认为 z f , z i , z o z^f,z^i,z^o zf,zi,zo不是 c j − 1 c_{j-1} cj1的函数,这其实是很片面的,因为这三者都是 h j − 1 h_{j-1} hj1的函数,自然也是 c j − 1 c_{j-1} cj1的函数,下面有完整的梯度分析。】

在LSTM的原始论文中并没有 z f z_f zf这样一个控制遗忘的门控,这会造成cell的状态是不可控的,于是加上遗忘门控,而连乘项的截断梯度的估计正好是 f t f_t ft

补充:上面的梯度计算并不完整,这篇文章很好的解释了这个问题Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass (weberna.github.io)

符号假设为
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_{t}=f_{t} * C_{t-1}+i_{t} * \tilde{C}_{t} Ct=ftCt1+itC~t
完整的梯度是
∂ h j ∂ h j − 1 = σ ′ W h ∂ C t ∂ C t − 1 = ∂ C t ∂ f t ∂ f t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ i t ∂ i t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C t − 1 ∂ C t − 1 ∼ ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 \frac{\partial h_{j}}{\partial h_{j-1}}=\sigma^{\prime} W^{h}\frac{\partial C_{t}}{\partial C_{t-1}}=\frac{\partial C_{t}}{\partial f_{t}} \frac{\partial f_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}} \frac{\partial i_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} \frac{\partial C_{t-1}^{\sim}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}} hj1hj=σWhCt1Ct=ftCtht1ftCt1ht1+itCtht1itCt1ht1+Ct1Ctht1Ct1Ct1ht1
进一步展开
∂ C t ∂ C t − 1 = C t − 1 σ ′ ( . ) W f ∗ o t − 1 tanh ⁡ ( C t − 1 ) + C ~ t σ ′ ( . ) W i ∗ o t − 1 tanh ⁡ ( C t − 1 ) + i t tanh ⁡ ′ ( . ) W c ∗ o t − 1 tanh ⁡ ( C t − 1 ) \frac{\partial C_{t}}{\partial C_{t-1}}=C_{t-1} \sigma^{\prime}(.) W_{f} * o_{t-1} \tanh \left(C_{t-1}\right)+\tilde{C}_{t} \sigma^{\prime}(.) W_{i} * o_{t-1} \tanh \left(C_{t-1}\right)+i_{t} \tanh ^{\prime}(.) W_{c} * o_{t-1} \tanh \left(C_{t-1}\right) Ct1Ct=Ct1σ(.)Wfot1tanh(Ct1)+C~tσ(.)Wiot1tanh(Ct1)+ittanh(.)Wcot1tanh(Ct1)
之前的结果等于 f t f_t ft,一直都是小于1,现在的结果可以大于1也可以小于1,而且 f t , i t , o t , C t ˉ f_t,i_t,o_t,\bar{C_t} ft,it,ot,Ctˉ都是网络自己学习的,因此可以很好的避免梯度消失的问题。

参考文献

LSTM介绍

【强烈推荐】 https://zhuanlan.zhihu.com/p/32085405

LSTM为什么能够缓解梯度消失

【截断梯度的分析,有些错误,注意甄别】 https://www.zhihu.com/question/44895610/answer/616818627

【完整的梯度分析】 https://zhuanlan.zhihu.com/p/109519044

猜你喜欢

转载自blog.csdn.net/weixin_43721070/article/details/121846328