cs224n学习笔记L7: 梯度消失和高级RNN

一、梯度消失及爆炸

1.1 RNN中的梯度消失(推导)

如图,反向传播更新隐藏层的向量时,如果途中的梯度较小,链式法则将使得远处的梯度信号

在这里插入图片描述
根据RNN中的隐藏层计算公式:
h ( t ) = σ ( W h h ( t 1 ) + W x x ( t ) + b 1 ) \boldsymbol { h } ^ { ( t ) } = \sigma \left( \boldsymbol { W } _ { h } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { W } _ { x } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { 1 } \right)
反向传播时计算相邻两个time-step的梯度:
h ( t ) h ( t 1 ) = diag ( σ ( W h h ( t 1 ) + W x x ( t ) + b 1 ) ) W h \frac { \partial \boldsymbol { h } ^ { ( t ) } } { \partial \boldsymbol { h } ^ { ( t - 1 ) } } = \operatorname { diag } \left( \sigma ^ { \prime } \left( \boldsymbol { W } _ { h } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { W } _ { x } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { 1 } \right) \right) \boldsymbol { W } _ { h }
计算计算不相邻两个time-step的梯度:
J ( i ) ( θ ) h ( j ) = J ( i ) ( θ ) h ( i ) j < t i h ( t ) h ( t 1 ) = J ( i ) ( θ ) h ( i ) W h ( i j ) j < t i diag ( σ ( W h h ( t 1 ) + W x x ( t ) + b 1 ) ) \begin{aligned} \frac { \partial J ^ { ( i ) } ( \theta ) } { \partial h ^ { ( j ) } } & = \frac { \partial J ^ { ( i ) } ( \theta ) } { \partial h ^ { ( i ) } } \prod _ { j < t \leq i } \frac { \partial h ^ { ( t ) } } { \partial h ^ { ( t - 1 ) } } \\ & = \frac { \partial J ^ { ( i ) } ( \theta ) } { \partial h ^ { ( i ) } } { W _ { h } ^ { ( i - j ) } } \prod _ { j < t \leq i } \operatorname { diag } \left( \sigma ^ { \prime } \left( W _ { h } h ^ { ( t - 1 ) } + W _ { x } x ^ { ( t ) } + b _ { 1 } \right) \right) \end{aligned}
考虑矩阵的L2范数:
J ( i ) ( θ ) h ( j ) J ( i ) ( θ ) h ( i ) W h ( i j ) j < t i diag ( σ ( W h h ( t 1 ) + W x x ( t ) + b 1 ) ) \left\| \frac { \partial J ^ { ( i ) } ( \theta ) } { \partial \boldsymbol { h } ^ { ( j ) } } \right\| \leq \left\| \frac { \partial J ^ { ( i ) } ( \theta ) } { \partial \boldsymbol { h } ^ { ( i ) } } \right\| \| \boldsymbol { W } _ { h } \| ^ { ( i - j ) } \prod _ { j < t \leq i } \| \operatorname { diag } \left( \sigma ^ { \prime } \left( \boldsymbol { W } _ { h } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { W } _ { x } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { 1 } \right) \right) \|
有论文证明了矩阵 W h W_h 的特征值如果小于1,那么上面的梯度就会随着距离指数级减小。由于我们使用sigmod非线性激活函数,边界值就是1。(个人觉得存疑,因为sigmod并没有对 W h W_h 使用)。
相反的,如果其特征值大于一,能被证明会发生梯度爆炸。

1.2 梯度消失会带来的问题

梯度消失会使后面距离较远的损失函数很难影响到前面的隐藏层的值,模型就无法有效的学习到前后距离较远的time-step的相互关联。因此,梯度可以看做前面的时间步对未来预测的影响。
在这里插入图片描述
对于较长的距离,例如从t到t+n,由于没有足够的误差信号传播来更新,预测阶段就不能获得正确的通过t预测t+n的数据的参数,这两者的预测数据就几乎没有关联。
一个容易出错的例子:
在这里插入图片描述

1.3 梯度爆炸带来的问题及解决办法

根据梯度更新的公式,如果梯度太大,一次更新步长太大,导致无法收敛,参数最终会溢出(inf),如果模型训练时loss异常增大,很可能是由于梯度爆炸导致(可以返回上一个检查点重新训练)。


下面这个图很直观的表现了梯度爆炸的可能原因,梯度陡崖的存在导致参数向它的反方向更新了很大一个步长。
在这里插入图片描述
可以通过梯度裁剪的方式来避免梯度爆炸。即设定一个梯度阈值,每次进行SGD更新前检查梯度,如果大于这个阈值,就对所有梯度进行等比例缩小。(相当于沿着同一个方向减小梯度,但是减小了步长)
在这里插入图片描述

二、更复杂的RNN

2.1 Long Short-Term Memory(LSTM)

2.1.1 LSTM的结构

梯度爆炸可以通过梯度裁剪来解决,而LSTM就是1997年提出的解决梯度消失的办法。其基本思想是通过在每个time-step增加一个记忆区来储存上下文的信息。
LSTM的基本结构:

  • 每个time-step有一个隐藏状态 h ( t ) h^{(t)} (hidden state)和一个cell状态 c ( t ) c^{(t)} (cell state)
  • 这两个都是长度为n的向量。
  • 记忆状态储存长距离的依赖信息。
  • LSTM对cell可以有三种操作:擦除、重写、读取。

LSTM中的门:

  • 通过三个关联的门(gates)来控制上面三种操作的选择。
  • 这三个门也分别是三个长度为n的向量。
  • 在每一个time-step的gates有开(1)、闭(0)两种状态或者这之间的某个值
  • 门的状态在每一个time-step受上下文信息影响动态调整

2.1.2 LSTM前向传播

对于输入序列 x ( t ) x^{(t)} :

  1. 每一个门是一个长为n的向量,都有一个对应权重矩阵。
    ( f o r g e t ) f ( t ) = σ ( W f h ( t 1 ) + U f x ( t ) + b f ) ( i n p u t ) i ( t ) = σ ( W i h ( t 1 ) + U i x ( t ) + b i ) ( o u t p u t ) o ( t ) = σ ( W o h ( t 1 ) + U o x ( t ) + b o ) \begin{aligned} (遗忘门-forget)\boldsymbol { f } ^ { ( t ) } & = \sigma \left( \boldsymbol { W } _ { f } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { f } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { f } \right) \\ (输入门-input) \boldsymbol { i } ^ { ( t ) } & = \sigma \left( \boldsymbol { W } _ { i } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { i } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { i } \right) \\ (输出门-output)\boldsymbol { o } ^ { ( t ) } & = \sigma \left( \boldsymbol { W } _ { o } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { o } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { o } \right) \end{aligned}
其中三个门控制的效果为均是以cell记忆单元命名:
遗忘门: c ( t 1 ) c^{(t-1)} c ( t ) c^{(t)} 的信息删除
输入门: h ( t 1 ) h^{(t-1)} c ( t ) c^{(t)} 的信息输入
输出门: c ( t ) c^{(t)} h ( t ) h^{(t)} 的信息输出
  1. 通过这些门来控制cell,进而计算hidden ( a b \boldsymbol a \circ \boldsymbol b 在这里是指元素级element-wise相乘)
    c ~ ( t ) = tanh ( W c h ( t 1 ) + U c x ( t ) + b c ) c ( t ) = f ( t ) c ( t 1 ) + i ( t ) c ~ ( t ) h ( t ) = o ( t ) tanh c ( t ) \begin{aligned} \tilde { \boldsymbol { c } } ^ { ( t ) } & = \tanh \left( \boldsymbol { W } _ { c } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { c } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { c } \right) \\ \boldsymbol { c } ^ { ( t ) } & = \boldsymbol { f } ^ { ( t ) } \circ \boldsymbol { c } ^ { ( t - 1 ) } + \boldsymbol { i } ^ { ( t ) } \circ \tilde { \boldsymbol { c } } ^ { ( t ) } \\ \boldsymbol { h } ^ { ( t ) } & = \boldsymbol { o } ^ { ( t ) } \circ \tanh \boldsymbol { c } ^ { ( t ) } \end{aligned}
    注意这里的tanh是双曲正切函数,值域为(-1, 1)现在在放出这个经典的图,注意左下角的图例,这次终于能看懂了:
    在这里插入图片描述

2.1.3 LSTM为什么解决了梯度消失的问题

LSTM事实上并没有从根本上解决梯度消失的问题,但它通过记忆之前的信息,从而产生了长距离的依赖。

那么为什么普通RNN我们通过梯度消失断定其依赖距离存在问题,而LSTM就只通过前向传播断定其有效呢?老师的回答是LSTM的记忆机制可能提供一种梯度的传播捷径,而普通RNN的hidden是其唯一的传播途径。

2.1.4 LSTM的发展历程

  • 2013-2015 LSTM称为最主要的工具之一
  • 这之后transformer他、取代了LSTM的地位

2.1.5 Bidirectional RNNs

单向的RNN, 某个time-step可能只与它的左侧context产生关联,通过双向RNN拼接,解决这个问题。但要注意双向RNN不是语言建模(用左侧序列预测下一个)
例如:BERT(Bidirectional Encoder Representations from Transformer), 现在来看这个明明就很有意思了,transformer作为encoder, 并且是双向的两层。
在这里插入图片描述
在这里插入图片描述

2.1.6 muti-layer(stacked) RNNs

RNN已经在序列单个维度上进行深度抽象,通过增加网络层数,可以在另一个维度上进行高度抽象,从而学习到更为复杂的特征。

2.2 GRU(gated recurrent units)

  • 2014年提出的LSTM的替代,结构更简单
  • 每个time-step有一个输入 x ( t ) x^{(t)} 和一个hidden state h ( t ) h^{(t)}
  • two gates:
    • update gate: u ( t ) = σ ( W u h ( t 1 ) + U u x ( t ) + b u ) u ^ { ( t ) } = \sigma \left( \boldsymbol { W } _ { u } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { u } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { u } \right)
    • reset gate: r ( t ) = σ ( W r h ( t 1 ) + U r x ( t ) + b r ) r ^ { ( t ) } = \sigma \left( \boldsymbol { W } _ { r } \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { U } _ { r } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { r } \right)
  • 门控隐藏层:
    h ~ ( t ) = tanh ( W h ( r ( t ) h ( t 1 ) ) + U h x ( t ) + b h ) h ( t ) = ( 1 u ( t ) ) h ( t 1 ) + u ( t ) h ~ ( t ) \begin{array} { l } \tilde { \boldsymbol { h } } ^ { ( t ) } = \tanh \left( \boldsymbol { W } _ { h } \left( \boldsymbol { r } ^ { ( t ) } \circ \boldsymbol { h } ^ { ( t - 1 ) } \right) + \boldsymbol { U } _ { h } \boldsymbol { x } ^ { ( t ) } + \boldsymbol { b } _ { h } \right) \\ \boldsymbol { h } ^ { ( t ) } = \left( 1 - \boldsymbol { u } ^ { ( t ) } \right) \circ \boldsymbol { h } ^ { ( t - 1 ) } + \boldsymbol { u } ^ { ( t ) } \circ \tilde { \boldsymbol { h } } ^ { ( t ) } \end{array}

2.3 GRU vs LSTM

  • RNN有很多种,但这两种使用最为广泛
  • 两者唯一的区别:GRU比LSTM运算更快
  • 当数据充足时,建议使用LSTM(因为参数更多,但依然是玄学)
  • 需要性能的话就GRU

三、梯度消失与爆炸广泛存在

  • 梯度消失爆炸存在于所有神经网络中,尤其是深度神经网络,这是由于链式法则产生。
  • 这导致底部的网络层更难以训练
  • 网络设置直接相连的捷径,使梯度能够有效传播,例如ResNet
    在这里插入图片描述
  • DenseNet: 直接将底层与上面的每一层暴力相连: 在这里插入图片描述
  • HighwayNet: 受LSTM启发,网络结构与resnet相似,但捷径部分是一个动态的门。

四、应用技巧

  1. 通常多层RNN表现更好, 例如Britz在2017的论文发现2-4层的RNN作为encoder, 4层的RNN作为decoder用于神经网络翻译效果最好。但是深度的网络需要使用skip-connections/dense-connections来训练更深的RNN.
  2. 基于Transformer的网络可能深达24层(BERT),其中就用到了很多skip-connections的技术。
  3. LSTM更强,但GRU更快
  4. 梯度裁剪
  5. RNN能用双向就用双向
  6. 多层RNN更强大,但可能需要用到skip/dense-connections的技巧。
发布了24 篇原创文章 · 获赞 8 · 访问量 7202

猜你喜欢

转载自blog.csdn.net/geek_hch/article/details/104543068
今日推荐