Backpropagation Through Time


This article is based on the book *Dive into Deep Learning* (动手学深度学习) and gives relatively detailed derivations for the corresponding chapter.

I. Backpropagation Derivation for the RNN

1. Problem Statement

The RNN's relations at time step $t$ are:

$$\left\{ \begin{array}{ll} h_t = W_{hx} x_t + W_{hh} h_{t-1} \\ O_t = W_{qh} h_t \end{array} \right.$$

Suppose the loss function is

$$L = \frac{1}{T} \sum_{t=1}^{T} l(O_t, y_t)$$

We want to compute

$$\frac{\partial L}{\partial W_{qh}}, \quad \frac{\partial L}{\partial W_{hx}}, \quad \frac{\partial L}{\partial W_{hh}}$$

Preliminaries: the chain rule for matrix derivatives and the basic rules and principles of matrix differentiation are assumed.
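To make the setup concrete, here is a minimal NumPy sketch of the forward pass and loss. The squared-error choice of $l$ and all shapes are illustrative assumptions; the derivation below holds for any differentiable $l$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_x, n_h, n_q = 4, 3, 5, 2            # sequence length, input/hidden/output sizes

W_hx = rng.normal(size=(n_h, n_x)) * 0.1
W_hh = rng.normal(size=(n_h, n_h)) * 0.1
W_qh = rng.normal(size=(n_q, n_h)) * 0.1
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
ys = [rng.normal(size=(n_q, 1)) for _ in range(T)]

h, L = np.zeros((n_h, 1)), 0.0
for t in range(T):
    h = W_hx @ xs[t] + W_hh @ h          # h_t = W_hx x_t + W_hh h_{t-1}
    O = W_qh @ h                         # O_t = W_qh h_t
    L += 0.5 * np.sum((O - ys[t]) ** 2) / T   # L = (1/T) sum_t l(O_t, y_t)
```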

2. Solution

First, we solve for $\frac{\partial L}{\partial W_{qh}}$.

For any time step $t$, we clearly have:

$$\frac{\partial L}{\partial O_t} = \frac{1}{T} \cdot \frac{\partial l(O_t, y_t)}{\partial O_t}, \qquad \mathrm{d}l = \mathrm{tr}\left( \left( \frac{\partial l}{\partial O_t} \right)^{T} \mathrm{d}O_t \right), \qquad O_t = W_{qh} h_t$$
Substituting $O_t$ into the differential and summing over all time steps (the $1/T$ factor turns $\partial l/\partial O_t$ into $\partial L/\partial O_t$), we get:

$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T} \left( \frac{\partial L}{\partial O_t} \right)^{T} \mathrm{d}W_{qh}\, h_t \right)$$
Using the cyclic property of the trace to move $h_t$ to the front:

$$\mathrm{d}L = \mathrm{tr}\left( \sum_{t=1}^{T} h_t \left( \frac{\partial L}{\partial O_t} \right)^{T} \mathrm{d}W_{qh} \right)$$
Therefore:

$$\frac{\partial L}{\partial W_{qh}} = \left( \sum_{t=1}^{T} h_t \left( \frac{\partial L}{\partial O_t} \right)^{T} \right)^{T} = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t}\, h_t^{T}$$
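This formula can be checked numerically. The sketch below assumes a squared-error loss $l(O_t, y_t) = \frac{1}{2}\lVert O_t - y_t\rVert^2$ (so $\partial L/\partial O_t = (O_t - y_t)/T$) and toy shapes, and compares $\sum_t \frac{\partial L}{\partial O_t} h_t^T$ against finite differences:

```python
import numpy as np

rng = np.random.default_rng(1)
T, n_x, n_h, n_q = 4, 3, 5, 2
W_hx = rng.normal(size=(n_h, n_x)) * 0.1
W_hh = rng.normal(size=(n_h, n_h)) * 0.1
W_qh = rng.normal(size=(n_q, n_h)) * 0.1
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
ys = [rng.normal(size=(n_q, 1)) for _ in range(T)]

def run(W_qh):
    """Forward pass; returns the loss, the h_t, and dL/dO_t for each t."""
    h, L, hs, dOs = np.zeros((n_h, 1)), 0.0, [], []
    for t in range(T):
        h = W_hx @ xs[t] + W_hh @ h
        o = W_qh @ h
        hs.append(h)
        dOs.append((o - ys[t]) / T)          # dL/dO_t for squared error
        L += 0.5 * np.sum((o - ys[t]) ** 2) / T
    return L, hs, dOs

# Analytic gradient: sum_t (dL/dO_t) h_t^T
_, hs, dOs = run(W_qh)
grad = sum(d @ h.T for d, h in zip(dOs, hs))

# Finite-difference gradient, entry by entry
num, eps = np.zeros_like(W_qh), 1e-6
for idx in np.ndindex(*W_qh.shape):
    Wp, Wm = W_qh.copy(), W_qh.copy()
    Wp[idx] += eps
    Wm[idx] -= eps
    num[idx] = (run(Wp)[0] - run(Wm)[0]) / (2 * eps)

assert np.allclose(grad, num, atol=1e-6)     # the two gradients agree
```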
Next we try to solve for $\frac{\partial L}{\partial W_{hx}}$ and $\frac{\partial L}{\partial W_{hh}}$, starting from the final time step $T$ (here $\mathrm{prod}(\cdot)$ denotes the appropriate product in the matrix chain rule). Recall the model:

$$\left\{ \begin{array}{ll} h_t = W_{hx} x_t + W_{hh} h_{t-1} \\ O_t = W_{qh} h_t \end{array} \right.$$

At the final step $T$, the loss depends on $h_T$ only through $O_T$:

$$\frac{\partial L}{\partial h_T} = \mathrm{prod}\left( \frac{\partial L}{\partial O_T}, \frac{\partial O_T}{\partial h_T} \right)$$
For time step $T-1$, $h_{T-1}$ affects the loss both through $O_{T-1}$ and through $h_T$:

$$\frac{\partial L}{\partial h_{T-1}} = \mathrm{prod}\left( \frac{\partial L}{\partial O_{T-1}}, \frac{\partial O_{T-1}}{\partial h_{T-1}} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_T}, \frac{\partial h_T}{\partial h_{T-1}} \right)$$

Likewise, for any time step $t < T$:

$$\frac{\partial L}{\partial h_t} = \mathrm{prod}\left( \frac{\partial L}{\partial O_t}, \frac{\partial O_t}{\partial h_t} \right) + \mathrm{prod}\left( \frac{\partial L}{\partial h_{t+1}}, \frac{\partial h_{t+1}}{\partial h_t} \right)$$
Evaluating these products with the same trace-based matrix chain rule used above for $\frac{\partial L}{\partial W_{qh}}$ gives:

$$\frac{\partial L}{\partial h_t} = W_{qh}^{T} \frac{\partial L}{\partial O_t} + W_{hh}^{T} \frac{\partial L}{\partial h_{t+1}}$$
Unrolling this recursion gives:

$$\frac{\partial L}{\partial h_t} = \sum_{i=t}^{T} \left( W_{hh}^{T} \right)^{T-i} W_{qh}^{T} \frac{\partial L}{\partial O_{T+t-i}}$$
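The repeated factor $\left(W_{hh}^{T}\right)^{T-i}$ in this sum is the root of the vanishing/exploding-gradient problem in plain RNNs: contributions from distant time steps are scaled by high matrix powers of $W_{hh}$. A small illustration (the matrix and the two scales are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
W_hh = rng.normal(size=(4, 4))
rho = np.max(np.abs(np.linalg.eigvals(W_hh)))   # spectral radius

norms = {}
for scale in (0.3, 1.5):
    W = scale * W_hh / rho                      # rescale so spectral radius = scale
    # Norm of the 50-step factor (W_hh^T)^50 that multiplies dL/dO_t
    # from 50 steps in the future
    norms[scale] = np.linalg.norm(np.linalg.matrix_power(W.T, 50))

# spectral radius < 1: the factor vanishes; > 1: it explodes
```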
Hence:

$$\frac{\partial L}{\partial W_{hx}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hx}} \right), \qquad \frac{\partial L}{\partial W_{hh}} = \mathrm{prod}\left( \frac{\partial L}{\partial h_t}, \frac{\partial h_t}{\partial W_{hh}} \right)$$
and therefore (the prod chain rule is applied exactly as above; the intermediate steps are left to the reader):

$$\frac{\partial L}{\partial W_{hx}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t}\, x_t^{T}, \qquad \frac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial h_t}\, h_{t-1}^{T}$$
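Putting the pieces together, here is a minimal self-contained BPTT sketch (squared-error loss and toy shapes are assumptions) that accumulates the backward recursion for $\partial L/\partial h_t$ and checks $\partial L/\partial W_{hx}$ and $\partial L/\partial W_{hh}$ against finite differences:

```python
import numpy as np

rng = np.random.default_rng(2)
T, n_x, n_h, n_q = 4, 3, 5, 2
W_hx = rng.normal(size=(n_h, n_x)) * 0.1
W_hh = rng.normal(size=(n_h, n_h)) * 0.1
W_qh = rng.normal(size=(n_q, n_h)) * 0.1
xs = [rng.normal(size=(n_x, 1)) for _ in range(T)]
ys = [rng.normal(size=(n_q, 1)) for _ in range(T)]

def forward(W_hx, W_hh):
    h, L = np.zeros((n_h, 1)), 0.0
    hs, dOs = [np.zeros((n_h, 1))], []           # hs[0] = h_0 = 0
    for t in range(T):
        h = W_hx @ xs[t] + W_hh @ h
        o = W_qh @ h
        hs.append(h)
        dOs.append((o - ys[t]) / T)              # dL/dO_t for squared error
        L += 0.5 * np.sum((o - ys[t]) ** 2) / T
    return L, hs, dOs

L, hs, dOs = forward(W_hx, W_hh)

# Backward pass: dL/dh_t = W_qh^T dL/dO_t + W_hh^T dL/dh_{t+1}
dW_hx, dW_hh = np.zeros_like(W_hx), np.zeros_like(W_hh)
dh = np.zeros((n_h, 1))                          # dL/dh_{T+1} = 0
for t in reversed(range(T)):
    dh = W_qh.T @ dOs[t] + W_hh.T @ dh
    dW_hx += dh @ xs[t].T                        # dL/dh_t x_t^T
    dW_hh += dh @ hs[t].T                        # dL/dh_t h_{t-1}^T

def num_grad(f, W, eps=1e-6):
    g = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g

assert np.allclose(dW_hx, num_grad(lambda W: forward(W, W_hh)[0], W_hx), atol=1e-6)
assert np.allclose(dW_hh, num_grad(lambda W: forward(W_hx, W)[0], W_hh), atol=1e-6)
```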
Combined with the earlier result:

$$\frac{\partial L}{\partial W_{qh}} = \sum_{t=1}^{T} \frac{\partial L}{\partial O_t}\, h_t^{T}$$
This completes the backpropagation derivation for the RNN.

II. Backpropagation Derivation for the LSTM

1. Problem Statement

$$\begin{array}{ll} I_t = \sigma\left( W_{xi} X_t + W_{hi} H_{t-1} + b_i \right) \\ F_t = \sigma\left( W_{xf} X_t + W_{hf} H_{t-1} + b_f \right) \\ O_t = \sigma\left( W_{xo} X_t + W_{ho} H_{t-1} + b_o \right) \\ C_t' = \tanh\left( W_{xc} X_t + W_{hc} H_{t-1} + b_c \right) \\ C_t = F_t \odot C_{t-1} + I_t \odot C_t' \\ H_t = O_t \odot \tanh(C_t) \\ Y_t = W_{qh} H_t + b_q \end{array}$$
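The equations above can be sketched as a single NumPy forward step (the sizes and initialization are illustrative assumptions; $\sigma$ is the logistic sigmoid):

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_h, n_q = 3, 5, 2                    # toy sizes

def sigma(z):                              # logistic sigmoid
    return 1.0 / (1.0 + np.exp(-z))

Wx = {g: rng.normal(size=(n_h, n_x)) * 0.1 for g in "ifoc"}   # W_xi, W_xf, W_xo, W_xc
Wh = {g: rng.normal(size=(n_h, n_h)) * 0.1 for g in "ifoc"}   # W_hi, W_hf, W_ho, W_hc
b = {g: np.zeros((n_h, 1)) for g in "ifoc"}
W_qh, b_q = rng.normal(size=(n_q, n_h)) * 0.1, np.zeros((n_q, 1))

def lstm_step(X, H, C):
    I = sigma(Wx["i"] @ X + Wh["i"] @ H + b["i"])         # input gate
    F = sigma(Wx["f"] @ X + Wh["f"] @ H + b["f"])         # forget gate
    O = sigma(Wx["o"] @ X + Wh["o"] @ H + b["o"])         # output gate
    C_cand = np.tanh(Wx["c"] @ X + Wh["c"] @ H + b["c"])  # candidate C'_t
    C = F * C + I * C_cand                 # C_t = F_t ⊙ C_{t-1} + I_t ⊙ C'_t
    H = O * np.tanh(C)                     # H_t = O_t ⊙ tanh(C_t)
    Y = W_qh @ H + b_q                     # Y_t = W_qh H_t + b_q
    return H, C, Y

H = C = np.zeros((n_h, 1))
for _ in range(4):                         # run a few steps on random inputs
    H, C, Y = lstm_step(rng.normal(size=(n_x, 1)), H, C)
```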

2. Solution


Reposted from blog.csdn.net/qq_41624557/article/details/117304654