【Deep learning】长短时记忆网络LSTM

写在前面

在前面讲的【Deep learning】循环神经网络RNN中,我们对RNN模型做了总结。由于RNN也有梯度消失的问题,因此很难处理长序列的数据,大牛们对RNN做了改进,得到了RNN的特例LSTM(Long Short-Term Memory),它可以避免常规RNN的梯度消失,因此在工业界得到了广泛的应用。下面我们就对LSTM模型做一个总结。

1.从RNN到LSTM


其中上图是传统RNN结构框架,而下图为LSTM结构框架,对比两图可以发现LSTM比传统RNN要复杂得多。它也是一种特殊的循环体结构,拥有三个“门”结构的特殊网络结构:遗忘门、输入门和输出门。下面就将具体介绍每一种门都是怎么工作的。

2.遗忘门

遗忘门(forget gate)顾名思义,是控制是否遗忘的,在LSTM中即以一定的概率控制是否遗忘上一层的隐藏细胞状态。举个栗子,比如一段文章中先介绍了某地原来是绿水蓝天,但是后来被污染了。于是在看到被污染了之后,循环神经网络就会“忘记”了之前绿水蓝天的状态。这就是“遗忘门”工作内容。

遗忘门子结构如下图所示:


图中输入的有上一序列的隐藏状态h(t−1)和本序列数据x(t),通过一个激活函数,一般是sigmoid,得到遗忘门的输出f(t)。由于sigmoid的输出f(t)在[0,1]之间,因此这里的输出f^{(t)}代表了遗忘上一层隐藏细胞状态的概率。用数学表达式即为:

其中Wf,Uf,bf为线性关系的系数和偏倚,和RNN中的类似。σ为sigmoid激活函数。

3.输入门

在RNN经历“遗忘门”之后,它还需要从当前的输入来补充最新的记忆,这就需要“输入门”来完成。

输入门(input gate)负责处理当前序列位置的输入,它的子结构如下图:

从图中可以看到输入门由两部分组成,第一部分使用了sigmoid激活函数,输出为i(t),第二部分使用了tanh激活函数,输出为a(t), 两者的结果后面会相乘再去更新细胞状态。用数学表达式即为:


其中Wi,Ui,bi,Wa,Ua,ba,为线性关系的系数和偏倚,和RNN中的类似。σ为sigmoid激活函数。

4. Cell状态更新

在研究LSTM输出门之前,我们要先看看LSTM之细胞状态。前面的遗忘门和输入门的结果都会作用于细胞状态C(t)。我们来看看从细胞状态C(t−1)如何得到C(t)。如下图所示:


细胞状态C(t)由两部分组成,第一部分是C(t−1)和遗忘门输出f(t)的乘积,第二部分是输入门的i(t)和a(t)的乘积,即:


其中,⊙为Hadamard积,在DNN中也用到过。

5.输出门

有了新的隐藏细胞状态C(t),我们就可以来看输出门了,子结构如下:


从图中可以看出,隐藏状态h(t)的更新由两部分组成,第一部分是o(t), 它由上一序列的隐藏状态h(t−1)和本序列数据x(t),以及激活函数sigmoid得到,第二部分由隐藏状态C(t)和tanh激活函数组成, 即:


6. LSTM前向传播算法

现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态h(t),C(t),模型参数几乎是RNN的4倍,因为现在多了Wf,Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo这些参数。


7.LSTM反向传播算法

有了LSTM前向传播算法,推导反向传播算法就很容易了, 思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新我们所有的参数,关键点在于计算所有参数基于损失函数的偏导数。

在RNN中,为了反向传播误差,我们通过隐藏状态h(t)的梯度δ(t)一步步向前传播。在LSTM这里也类似。只不过我们这里有两个隐藏状态h(t)和C(t)。这里我们定义两个δ,即:


反向传播时只使用了δ(t)C,变量δ(t)h仅为帮助我们在某一层计算用,并没有参与反向传播,这里要注意。如下图所示:


而在最后的序列索引位置τ的δ(τ)h和 δ(τ)C为:


接着我们由δ(t+1)C反向推导δ(t)C。

δ(t)h的梯度由本层的输出梯度误差决定,即:


而δ(t)C的反向梯度误差由前一层δ(t+1)C的梯度误差和本层的从h(t)传回来的梯度误差两部分组成,即:


有了δ(t)h和δ(t)C, 计算这一大堆参数的梯度就很容易了,这里只给出Wf的梯度计算过程,其他的Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo,V,c的梯度大家只要照搬就可以了。



小结

LSTM虽然结构复杂,但是只要理顺了里面的各个部分和之间的关系,进而理解前向反向传播算法是不难的。当然实际应用中LSTM的难点不在前向反向传播算法,这些有算法库帮你搞定,模型结构和一大堆参数的调参才是让人头痛的问题。不过,理解LSTM模型结构仍然是高效使用的前提。

一下参考资料可供参考:

Understanding LSTM Networks

循环神经网络(RNN, Recurrent Neural Networks)介绍

LSTM模型与前向反向传播算法


猜你喜欢

转载自blog.csdn.net/Kaiyuan_sjtu/article/details/80646157