从循环神经网络(RNN)到LSTM网络

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/diligent_321/article/details/53365621

  通常,数据的存在形式有语音、文本、图像、视频等。因为我的研究方向主要是图像识别,所以很少用有“记忆性”的深度网络。怀着对循环神经网络的兴趣,在看懂了有关它的理论后,我又看了Github上提供的tensorflow实现,觉得收获很大,故在这里把我的理解记录下来,也希望对大家能有所帮助。本文将主要介绍RNN相关的理论,并引出LSTM网络结构(关于对tensorflow实现细节的理解,有时间的话,在下一篇博文中做介绍)。

循环神经网络

  RNN,也称作循环神经网络(还有一种深度网络,称作递归神经网络,读者要区别对待)。因为这种网络有“记忆性”,所以主要是应用在自然语言处理(NLP)和语音领域。与传统的Neural network不同,RNN能利用上”序列信息”。从理论上讲,它可以利用任意长序列的信息,但由于该网络结构存在“消失梯度”问题,所以在实际应用中,它只能回溯利用与它接近的time steps上的信息。

1. 网络结构

  常见的神经网络结构有卷积网络、循环网络和递归网络,栈式自编码器和玻尔兹曼机也可以看做是特殊的卷积网络,区别是它们的损失函数定义成均方误差函数。递归网络类似于数据结构中的树形结构,且其每层之间会有共享参数。而最为常用的循环神经网络,它的每层的结构相同,且每层之间参数完全共享。RNN的缩略图和展开图如下,

  尽管RNN的网络结构看上去与常见的前馈网络不同,但是它的展开图中信息流向也是确定的,没有环流,所以也属于forward network,故也可以使用反向传播(back propagation)算法来求解参数的梯度。另外,在RNN网络中,可以有单输入、多输入、单输出、多输出,视具体任务而定。

2. 损失函数

  在输出层为二分类或者softmax多分类的深度网络中,代价函数通常选择交叉熵(cross entropy)损失函数,前面的博文中证明过,在分类问题中,交叉熵函数的本质就是似然损失函数。尽管RNN的网络结构与分类网络不同,但是损失函数也是有相似之处的。
  假设我们采用RNN网络构建“语言模型”,“语言模型”其实就是看“一句话说出来是不是顺口”,可以应用在机器翻译、语音识别领域,从若干候选结果中挑一个更加靠谱的结果。通常每个sentence长度不一样,每一个word作为一个训练样例,一个sentence作为一个Minibatch,记sentence的长度为T。为了更好地理解语言模型中损失函数的定义形式,这里做一些推导,根据全概率公式,则一句话是“自然化的语句”的概率为

p(w1,w2,...,wT)=p(w1)×p(w2|w1)×...×p(wT|w1,w2,...,wT1)
  所以语言模型的目标就是最大化 P(w1,w2,...,wT) 。而损失函数通常为最小化问题,所以可以定义
Loss(w1,w2,...,wT|θ)=logP(w1,w2,...,wT|θ)
  那么公式展开可得
Loss(w1,w2,...,wT|θ)=(logp(w1)+logp(w2|w1)+...+logp(wT|w1,w2,...,wT1))
  展开式中的每一项为一个softmax分类模型,类别数为所采用的词库大小(vocabulary size),相信大家此刻应该就明白了,为什么使用RNN网络解决语言模型时,输入序列和输出序列错了一个位置了。

3. 梯度求解

  在训练任何深度网络模型时,求解损失函数关于模型参数的梯度,应该算是最为核心的一步了。在RNN模型训练时,采用的是BPTT(back propagation through time)算法,这个算法其实实质上就是朴素的BP算法,也是采用的“链式法则”求解参数梯度,唯一的不同在于每一个time step上参数共享。从数学的角度来讲,BP算法就是一个单变量求导过程,而BPTT算法就是一个复合函数求导过程。接下来以损失函数展开式中的第3项为例,推导其关于网络参数U、W、V的梯度表达式(总损失的梯度则是各项相加的过程而已)。
  为了简化符号表示,记 E3=logp(w3|w1,w2) ,则根据RNN的展开图可得,

s3=tanh(U×x3+W×s2)  s2=tanh(U×x2+W×s1)s1=tanh(U×x1+W×s0)  s0=tanh(U×x0+W×s1)(1)

  所以,

s3W=s3W1+s3s2×s2Ws2W=s2W1+s2s1×s1Ws1W=s1W0+s1s0×s0Ws0W=s0W1(2)

  说明一下,为了更好地体现复合函数求导的思想,公式(2)中引入了变量 W1 ,可以把 W1 看作关于W的函数,即 W1=W 。另外,因为 s1 表示RNN网络的初始状态,为一个常数向量,所以公式(2)中第4个表达式展开后只有一项。所以由公式(2)可得,

s3W=s3W1+s3s2×s2W1+s3s2×s2s1×s1W1+s3s2×s2s1×s1s0×s0W1(3)

  简化得下式,

s3W=s3W1+s3s2×s2W1+s3S1×s1W1+s3s0×s0W1(4)

  继续简化得下式,

s3W=i=03s3si×siW1(5)

3.1 E3 关于参数V的偏导数

  记t=3时刻的softmax神经元的输入为 a3 ,输出为 y3 ,网络的真实标签为 y(1)3 。根据函数求导的“链式法则”,所以有下式成立,

E3V=E3a3×a3V=(y(1)3y3)s3(6)

3.2 E3 关于参数W的偏导数

  关于参数W的偏导数,就要使用到上面关于复合函数的推导过程了,记 zi 为t=i时刻隐藏层神经元的输入,则具体的表达式简化过程如下,

E3W=E3s3×s3W=E3a3×a3s3×s3W=k=03E3a3×a3s3×s3sk×skW1=k=03E3a3×a3s3×s3sk×skzk×zkW1=k=03E3zk×zkw1(7)

  类似于标准的BP算法中的表示,定义 δmn=Emzn ,那么可以得到如下递推公式,

δ32=E3z3×z3z2=E3z3×z3s2×s2z2=(δ33W)(1s22)(8)

  那么,公式(7)可以转化为下式,

E3W=k=03δ3k×zkw1(9)

  显然,结合公式(8)中的递推公式,可以递推求解出公式(9)中的每一项,那么 E3 关于参数W的偏导数便迎刃而解了。

3.3 E3 关于参数U的偏导数

  关于参数U的偏导数求解过程,跟W的偏导数求解过程非常类似,在这里就不介绍了,感兴趣的读者可以结合3.2的思路尝试着自己推导一下。

4. 梯度消失问题

  当网络层数增多时,在使用BP算法求解梯度时,自然而然地就会出现“vanishing gradient“问题(还有一种称作“exploding gradient”,但这种情况在训练模型过程中易于被发现,所以可以通过人为控制来解决),下面我们从数学的角度来证明RNN网络确实存在“vanishing gradient“问题,推导公式如下,

E3W=k=03E3a3×a3s3×s3sk×skW1=k=03E3a3×a3s3×(i=k+13sisi1)×skW1(10)

  大家应该注意到了,上面的式子中有一个连乘式,对于其中的每一项,满足 si=activation(U×xi+W×si1) ,当激活函数为tanh时, sisi1 的取值范围为[0, 1]。当激活函数为sigmoid时, sisi1 的取值范围为[0, 1/4](简单的一元函数求导,这里就不展开了)。因为这里我们选择t=3时刻的输出损失,所以连乘的式子的个数并不多。但是我们可以设想一下,对于深度的网络结构而言,若选择tanh或者sigmoid激活函数,对于公式(10)中k取值较小的那一项,一定满足 3i=k+1sisi1 趋近于0,从而导致了消失梯度问题。

  我们再从直观的角度来理解一下消失梯度问题,对于RNN时刻T的输出,其必定是时刻t=1,…,T-1的输入综合作用的结果,也就是说更新模型参数时,要充分利用当前时刻以及之前所有时刻的输入信息。但是如果发生了”消失梯度”问题,就会意味着,距离当前时刻非常远的输入数据,不能为当前模型参数的更新做贡献,所以在RNN的编程实现中,才会有“truncated gradient”这一概念,“截断梯度”就是在更新参数时,只利用较近的时刻的序列信息,把那些“历史悠久的信息”忽略掉了。

  解决“消失梯度问题”,我们可以更换激活函数,比如采用Relu(rectified linear units)激活函数,但是更好的办法是使用LSTM或者GRU架构的网络。

LSTM网络

  为了解决原始RNN网络结构存在的“vanishing gradient”问题,前辈们设计了LSTM这种新的网络结构。但从本质上来讲,LSTM是一种特殊的循环神经网络,其和RNN的区别在于,对于特定时刻t,隐藏层输出 st 的计算方式不同。故对LSTM网络的训练的思路与RNN类似,仅前向传播关系式不同而已。值得一提的是,在对LSTM网络进行训练时,cell state c[0]和hidden state s[0]都是随机初始化得到的。
  GRU(Gated Recurrent Unit)是2014年提出来的新的RNN架构,它是简化版的LSTM,在超参数(hyper-parameters)均调优的前提下,这两种RNN架构的性能相当,但是GRU架构的参数少,所以需要的训练样本更少,易于训练。LSTM和GRU架构的网络图如下,

  关于LSTM网络结构相关的理论,请参见http://colah.github.io/posts/2015-08-Understanding-LSTMs/,相信也只有这样的大牛能把LSTM解析的如此浅显易懂。这里还需要补充说明一下,关于LSTM网络的参数求偏微分,如果我们手动求解的话,也是跟RNN类似的思路,但由于LSTM网络结构比较复杂,手动算的话,式子会变得非常复杂,我们便可以借助深度学习框架的自动微分功能了,现在的框架也都支持自动微分的,比如theano、tensorflow等。

参考资料:http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/
     “A tutorial on training recurrent neural networks”. H. Jaeger, 2002.

猜你喜欢

转载自blog.csdn.net/diligent_321/article/details/53365621