理解RNN梯度消失和弥散以及LSTM为什么能解决

根据RNN的BPTT推导,我们可以得到下面的式子:


而又有:


因此,每一个Sj对Sj-1的偏导都等于tanh‘(..)*W

注意到:tanh的梯度最大只能是1,而更多时候都是趋近于0的饱和状态,当求导得到的jacobian矩阵存在一些元素趋近于0,多个矩阵相乘,会使得梯度很快消失。这时候有人会问,为什么不将tanh换成ReLU呢?这样不就可以解决梯度消失了吗?

确实,换成ReLU在一定程度上可以解决梯度消失的问题,但是:

那为什么同样的方法在RNN中不奏效呢?其实这一点Hinton在它的IRNN论文里面(arxiv:[1504.00941] A Simple Way to Initialize Recurrent Networks of Rectified Linear Units)是很明确的提到的:

也就是说在RNN中直接把激活函数换成ReLU会导致非常大的输出值。

一方面,将tanh换成ReLU,最后计算的结果会变成多个W连乘,如果W中存在特征值>1的,那么经过BPTT连乘后得到的值会爆炸,产生梯度爆炸的问题,使得RNN仍然无法传递较远距离。


因此,RNN没办法BPTT较远的距离的原因就是如此。

那么怎么解决呢?这就引入了LSTM的思想了:

如果对于所有的 m

|w_{l_{m}l_{m-1}}y_{lm}'(t-m)|>1.0

则梯度会随着反向传播层数的增加而呈指数增长,导致梯度爆炸。

如果对于所有的 m

|w_{l_{m}l_{m-1}}y_{lm}'(t-m)|<1.0

则在经过多层的传播后,梯度会趋向于0,导致梯度弥散(消失)。

Sepp Hochreiter 和 Jürgen Schmidhuber 在他们提出 Long Short Term Memory 的文章里讲到,为了避免梯度弥散和梯度爆炸,一个 naive 的方法就是强行让 error flow 变成一个常数:

|y_{jj}(t)'w_{jj}|=1.0

w_{jj} 就是RNN里自己到自己的连接。他们把这样得到的模块叫做CEC(constant error carrousel),很显然由于上面那个约束条件的存在,这个CEC模块是线性的。这就是LSTM处理梯度消失的问题的动机。

通俗地讲:RNN中,每个记忆单元h_t-1都会乘上一个W和激活函数的导数,这种连乘使得记忆衰减的很快,而LSTM是通过记忆和当前输入"相加",使得之前的记忆会继续存在而不是受到乘法的影响而部分“消失”,因此不会衰减。但是这种naive的做法太直白了,实际上就是个线性模型,在学习效果上不够好,因此LSTM引入了那3个门:

作者说所有“gradient based”的方法在权重更新都会遇到两个问题:

input weight conflict 和 output weight conflict

大意就是对于神经元的权重\bm{w} ,不同的数据(\bm{x_i}, \ \bm{y_i}) 所带来的更新是不同的,这样可能会引起冲突(比如有些输入想让权重变小,有些想让它变大)。网络可能需要选择性地“忘记”某些输入,以及“屏蔽”某些输出以免影响下一层的权重更新。为了解决这些问题就提出了“门”。

举个例子:在英文短语中,主语对谓语的状态具有影响,而如果之前同时出现过第一人称和第三人称,那么这两个记忆对当前谓语就会有不同的影响,为了避免这种矛盾,我们希望网络可以忘记一些记忆来屏蔽某些不需要的影响。

因为LSTM对记忆的操作是相加的,线性的,使得不同时序的记忆对当前的影响相同,为了让不同时序的记忆对当前影响变得可控,LSTM引入了输入门和输出门,之后又有人对LSTM进行了扩展,引入了遗忘门。

总结一下:LSTM把原本RNN的单元改造成一个叫做CEC的部件,这个部件保证了误差将以常数的形式在网络中流动 ,并在此基础上添加输入门和输出门使得模型变成非线性的,并可以调整不同时序的输出对模型后续动作的影响。










猜你喜欢

转载自blog.csdn.net/hx14301009/article/details/80401227
今日推荐