【深度学习】RNN的梯度消失/爆炸与正交初始化

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shenxiaolu1984/article/details/71508892

在训练较为复杂的RNN类型网络时,有时会采取正交方法初始化(orthogonal initialization)网络参数。本文用简单的例子介绍其中的原因。

本文较大程度参考了这篇博客

简单例子

RNN具有如下形式:

ht=fh(Wht1+Vxt)

yt=fy(Uht)

我们考虑一个极端简化的版本:没有输入,激活函数为直通,直接输出隐变量。

yt=Wyt1

计算第t步的输出时,需要计算参数矩阵的t次幂:

yt=Wty0

为了计算简便,可以把方阵 W 进行正交分解:
W=QΛQ1

yt=QΛtQ1y0

其中 Q 是单位正交矩阵; Λ 是对角阵,计算其t次幂只需要把对角线上的特征值进行幂运算即可。

优化网络参数时,使用简单的二范数代价:

E=||ytyt¯¯¯||2

为了更新参数,需要计算代价对于参数的导数(是个标量):

EWi=2(ytyt¯¯¯)TytWi

梯度消失/爆炸

当RNN步数t增加时, yt/Wi 会怎样变化呢?

为书写直观假设 y 是个二维向量。于是 W 有四个参数,我们用正交分解的形式表示出来。

Q=[w1w2w2w1],Q1=[w1w2w2w1]

Λ=diag(w3,w4)

可以直接写出 yt 的表达式(善用Matlab的syms功能):

yt=[w21wt3+w22wt4w1w2(wt4wt3)w1w2(wt4wt3)w21wt4+w22wt3]y0

分别写出对四个参数的导数(长度为2的矢量):

ytw1=[2w1wt3w2(wt4wt3)w2(wt4wt3)2w1wt4]y0

ytw2=[2w2wt4w1(wt4wt3)w1(wt4wt3)2w2wt3]y0

ytw3=[tw21wt13w1w2wt13tw1w2wt13tw22wt13]y0

ytw4=[tw22wt14w1w2wt14tw1w2wt14tw21wt14]y0

重点:每一项里都有 w3,w4 的t或t-1次幂。不考虑细节,这个推导说明:

代价对于参数的导数 参数矩阵特征值 λi 的t次方。

如果 |λi|>1 ,则步数增加时 λt 超出浮点范围,发生梯度爆炸,优化无法收敛;
如果 |λi|<1 ,步数增加时 λt 变为0,发生梯度消失,优化停滞不前。

正交初始化

理想的情况是,特征值绝对值为1。则无论步数增加多少,梯度都在数值计算的精度内。

这样的参数矩阵 W 单位正交阵

把转移矩阵初始化为单位正交阵,可以避免在训练一开始就发生梯度爆炸/消失现象,称为orthogonal initialization。

其他解决方法

除了正交初始化,在RNN类型网络训练中,还可以使用如下方法解决梯度消失/爆炸问题:
- 使用ReLU激活函数->解决梯度消失
- 对梯度进行剪切(gradient clipping)->解决梯度爆炸
- 引入更复杂的结构,例如LSTM、GRU->解决梯度消失

猜你喜欢

转载自blog.csdn.net/shenxiaolu1984/article/details/71508892