循环神经网络(RNN)原理解析

对于具有时间维度的数据,比如阅读的文本、说话时发出的语音信号、随着时间变化的股市参数等,这类数据并不一定具有局部相关性,同时数据在时间维度上的长度也是可变的。

这一特性导致一般的神经网络难以处理,而循环神经网络则以序列数据为输入,在序列的演进方向进行递归且所有节点按链式连接,具有记忆性、参数共享并且图灵完备,因此在对序列的非线性特征进行学习时具有一定优势,可以很好地解决问题。

下面就来详细说明其原理。


RNN的结构

基础的神经网络只在层与层之间建立了权连接,而 RNN 则在此基础上在层之间的神经元之间也建立了权连接,如图:
在这里插入图片描述
基础的神经网络没有 a 加入,只有从 x 到 y 经历几层网络的过程,每个独立的 x 输入之间互不影响;而 RNN 通过添加 a 将每个 x 联系起来,因此它能产生记忆性并找到序列的规律,这样就可以在时间维度上处理变化的数据。

图中每个箭头代表做一次函数运算,长方形内包含4个圆的图形代指神经网络的运算过程。

上面是 RNN 的标准结构,然而在实际中这一种结构并不能解决所有问题。比如电影评论情感分类,输入的是一段文字,输出的却是一个表示类型的数,那就只需要单个输出,如图:
在这里插入图片描述
还有一对多的结构,比如输入一个主题生成一整篇文章,如图:
在这里插入图片描述
最后就是常规的多对多结构,如果输入序列的大小和输出序列的大小一样,如图:
在这里插入图片描述
如果输入序列的大小和输出序列的大小不一样,比如机械翻译出的句子和原句子长度一般不一样,如图:
在这里插入图片描述
当然这些是最基本的结构,针对不同情况 RNN 还有很多不一样的特殊结构。


运算流程

上面说过第一张图中每个箭头表示一个函数运算过程,那么这里就解说一下每个函数及其运作情况。

除了第一个 a 要初始化外,后面其它 a 都要在前一个 a 和 x 的基础上进行运算:
a < t > = g ( W a a a < t 1 > + W a x x < t > + b a ) a^{<t>}=g(W_{aa}a^{<t-1>}+W_{ax}x^{<t>}+b_a)

其中 g ( ) g() 是激活函数; W a a W_{aa} W a x W_{ax} 是权值,两个权值为所有 a 和 x 共享,不会单独对每个输入设置不同的值; b a b_a 为偏置。

a 计算好之后,就可以以此为基础得出输出了,运算方程:
y ^ < t > = g ( W y a a < t > + b y ) \hat{y}^{<t>}=g(W_{ya}a^{<t>}+b_y)

其中 g ( ) g() 也是激活函数

有时计算过程比较简单,a 可以直接当作输出,即不需要上述变换。


训练方法

RNN 常采用通过时间的反向传播算法(Back-propagation through time),因为 RNN 处理时间序列数据,所以要基于时间反向传播,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛,核心还是求各个参数的梯度。

因此在求结果时还要计算损失函数 L(),先前向一步步计算完成,再反向由损失一步步计算权重的偏差并修正。

W y a W_{ya} 的过程相对简单,只关注当前值:
L ( t ) W y a = L ( t ) g ( t ) g ( t ) W y a \frac{∂ L^{(t)}}{∂ W_{ya}} = \frac{∂ L^{(t)}}{∂ g^{(t)}}\frac{∂ g^{(t)}}{∂ W_{ya}}

RNN 的损失会随着时间累加,所以要求出所有时刻的偏导:
L ( t ) W y a = t = 1 n L ( t ) g ( t ) g ( t ) W y a \frac{∂L^{(t)}}{∂W_{ya}} =\sum_{t=1}^n \frac{∂ L^{(t)}}{∂ g^{(t)}}\frac{∂ g^{(t)}}{∂ W_{ya}}

W a a W_{aa} W a x W_{ax} 的计算需要涉及到历史数据,其偏导求起来相对复杂。在某个时刻的偏导数,需要追溯这个时刻之前所有时刻的信息
L ( t ) W a a = k = 1 t L ( t ) g ( t ) g ( t ) a ( t ) ( j = k + 1 t a ( j ) a ( j 1 ) ) a ( k ) W a a \frac{∂L^{(t)}}{∂W_{aa}} =\sum_{k=1}^t \frac{∂ L^{(t)}}{∂ g^{(t)}}\frac{∂ g^{(t)}}{∂ a^{(t)}}(\prod_{j=k+1}^t\frac{∂ a^{(j)}}{∂ a^{(j-1)}})\frac{∂ a^{(k)}}{∂W_{aa}}

L ( t ) W a x = k = 1 t L ( t ) g ( t ) g ( t ) a ( t ) ( j = k + 1 t a ( j ) a ( j 1 ) ) a ( k ) W a x \frac{∂L^{(t)}}{∂W_{ax}} =\sum_{k=1}^t \frac{∂ L^{(t)}}{∂ g^{(t)}}\frac{∂ g^{(t)}}{∂ a^{(t)}}(\prod_{j=k+1}^t\frac{∂ a^{(j)}}{∂ a^{(j-1)}})\frac{∂ a^{(k)}}{∂W_{ax}}

激活函数是嵌套在里面的,如果把激活函数放进去,拿出中间累乘的那部分。

累乘会导致激活函数导数的累乘,进而会导致梯度消失梯度爆炸现象的发生。

改善梯度消失的方法:选取更好的激活函数、改变传播结构

猜你喜欢

转载自blog.csdn.net/weixin_44613063/article/details/106629906