RNN训练详解,通俗易懂

  1. Xt代表输入序列中的第t步元素,例如语句中的一个汉字。一般使用一个one-hot向量来表示,向量的长度是训练所用的汉字的总数(或称之为字典大小),而唯一为1的向量元素代表当前的汉字。

  2. St代表第t步的隐藏状态,其计算公式为St=tanh(U*Xt+W*St-1)。也就是说,当前的隐藏状态由前一个状态和当前输入计算得到。考虑每一步隐藏状态的定义,可以把St视为一块内存,它保存了之前所有步骤的输入和隐藏状态信息。S-1是初始状态,被设置为全0

  3. Ot是第t步的输出。可以把它看作是对第t+1步的输入的预测,计算公式为:Ot=softmax(V*St)。可以通过比较OtXt+1之间的误差来训练模型

  4. U,V,W是RNN的参数,并且在展开之后的每一步中依然保持不变。这就大大减少了RNN中参数的数量。

 

假设真实的输出应该是,那么误差可以定义为,是训练样本的index。整个网络的误差

我们将RNN再放大一些,看看细节

 

矩阵向量化表示

所以梯度为:

其中是点乘符号,即对应元素乘。

简单点来说,RNN的训练过程:假设一个输入文本长度为20,计算t=20时的的loss,然后对loss求导(W,U,V),由于是前后相互影响的,整个求导是一个叠加的过程,即可得到求导后的变化量,整个UVW是共享的。

(综合了多位分享者的内容,具体是谁由于没保存记录所以就没加上,能解决一点疑惑就好)

 

猜你喜欢

转载自blog.csdn.net/cuipanguo/article/details/82144198