RNN中梯度消失和爆炸的问题公式推导

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014046022/article/details/83859183

RNN

首先来看一下经典的RRN的结构图,这里 x x 是输入 W W 是权重矩阵 (RNN的权重矩阵是共享的所以都是W) h h 是隐藏状态 y y 是输出
在这里插入图片描述

RNN简单公式定义

h t = W f ( h t 1 ) + W ( h x ) x [ t ] h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]}
y t = W ( S ) f ( h t ) y_{t} = W^{(S)}*f(h_t)
其中, h t h_t 表示 t 时刻的隐藏状态 x [ t ] x_{[t]} 表示 t 时刻的输入 y t y_t 表示 t 时刻的输出。我们记总体的error为 E E 那么 E E 有如下表达式:
E = t = 1 T E t W E = \sum_{t=1}^{T}\frac{\partial E_t}{\partial W}
总体的误差是所有时刻 t 的误差的累加。那么继续往下展开, 根据链式法则:
E t W = k = 1 t E t y t y t h t h t h k h k W \frac{\partial E_t}{\partial W} = \sum_{k=1}^{t}\frac{\partial E_t}{\partial y_t} \frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_k} \frac{\partial h_k}{\partial W}
继续往下展开有:
h t h k = j = k + 1 t h j h j 1 \frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}
注意到: h t = W f ( h t 1 ) + W ( h x ) x [ t ] h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]} ,上式的每个偏导其实是一个Jacobian式

在这里插入图片描述

考虑Jacobians的范数,令:
h j h j 1 W T d i a g [ f ( h j 1 ) ] β w β h ||\frac{\partial h_j}{\partial h_{j-1}} || \leq ||W^{T}|| *||diag[f'(h_{j-1})]|| \leq \beta_w*\beta_h
其中, β w , β h \beta_w ,\beta_h 表示正则化的上界。将上式回代到连乘的式子得:
h t h k = j = k + 1 t h j h j 1 ( β w β h ) t k ||\frac{\partial h_t}{\partial h_k} ||= ||\prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}|| \leq(\beta_w *\beta_h)^{t-k}
这里得 t 表示 time-step,也就是序列越长t会越大,即就变成了长期依赖的问题。注意到 ( β w β h ) t k (\beta_w *\beta_h)^{t-k} 这项其实与矩阵的W的初始化有关,假设初始化一些非常小的数,W的范数也会变得很小,也就是 β w \beta_w 会变得比较小,那么随着t的增长,这一指数项会趋近于0而导致梯度消失,相反,如果初始化成为大于1的数,则随着t的增长,会导致梯度爆炸。

猜你喜欢

转载自blog.csdn.net/u014046022/article/details/83859183
今日推荐