retropropagación a través del tiempo
Basado en el libro "Hands-On Deep Learning", este artículo ofrece una derivación relativamente detallada de los capítulos correspondientes.
1. Derivación de retropropagación de RNN
1. Descripción del problema
Esta es la expresión relacional de la red RNN en el tiempo t:
{ ht = W hxxt + W hhht − 1 O t = W qhht \left \{ \begin{array}{ll} h_t = W_{hx}x_t + W_{ hh} h_{t-1} \\ O_t = W_{qh}h_t \\ \end{matriz} \right .{
ht=Wh xXt+Wh hht − 1Ot=Wq hht
Supongamos que la función de pérdida
L = 1 T ∑ t = 1 nl ( O t , yt ) L = \frac{1}{T}\sum_{t=1}^{n}l(O_t, y_t)L=T1t = 1∑nyo ( Ot,yt)
欲求
∂ L ∂ W qh , ∂ L ∂ W hx , ∂ L ∂ W hh \frac{\parcial L}{\parcial W_{qh}}, \frac{\parcial L}{\parcial W_{hx}}, \frac{\parcial L}{\parcial W_{hh}}∂ Wq h∂ L,∂ Wh x∂ L,∂ Wh h∂ L
Algunas preparaciones: Es necesario dominar la derivación de cadenas de matriz y las reglas y principios básicos de derivación .
2. Resolución de problemas
Primero , resuelve para ∂ L ∂ W qh \frac{\partial L}{\partial W_{qh}}∂ Wq h∂ L
Para cualquier momento ttt,显然有:
∂ L ∂ O t = 1 T ⋅ ∂ l ( O t , yt ) ∂ O tdl = tr ( ( ∂ l ∂ O t ) T ⋅ re O t ) O t = W qhht \frac{\ L parcial}{\O_t parcial} = \frac{1}{T} \cdot \frac{\l parcial(O_t, y_t)}{\O_t parcial} \\ \mathrm{d}l = tr\left( { \left( \frac{\parcial l}{\parcial O_t} \right)}^T \cdot \mathrm{d}O_t \right) \\ O_t = W_{qh}h{t}∂O _t∂ L=T1⋅∂O _t∂ l ( Ot,yt)dl _=t r( (∂O _t∂ l)T⋅d Ot)Ot=Wq hh t
Por lo tanto,O t O_tOt带入微分式中,有:
re L = tr ( ∑ i = 1 T ( ∂ l ∂ O t ) T re W qh ⋅ ht ) \mathrm{d}L = tr\left( \sum_{i=1} ^{T}{\left( \frac{\parcial l}{\parcial O_t} \right)}^T \mathrm{d}W_{qh} \cdot h_t \right)d L=t r(yo = 1∑T(∂O _t∂ l)Td W_q h⋅ht)
va aht h_thtPóngalo en el lado derecho de la traza:
d L = tr ( ∑ i = 1 T ht ( ∂ l ∂ O t ) T d W qh ) \mathrm{d}L = tr\left( \sum_{i=1 } ^{T}h_t{\left( \frac{\parcial l}{\parcial O_t} \right)}^T \mathrm{d}W_{qh} \right)d L=t r(yo = 1∑Tht(∂O _t∂ l)Td W_q h)
因此:
∂ L ∂ W qh = ( ∑ yo = 1 T ht ( ∂ l ∂ O t ) T ) T = ∑ i = 1 T ∂ l ∂ O t ( ht ) T \frac{\parcial L}{\ parcial W_{qh}} = \left( \sum_{i=1}^{T}h_t{\left( \frac{\partial l}{\partial O_t} \right)}^T \right)^T = \sum_{i=1}^{T} \frac{\parcial l}{\parcial O_t} {\left( h_t \right)}^T∂ Wq h∂ L=(yo = 1∑Tht(∂O _t∂ l)t )T=yo = 1∑T∂O _t∂ l( ht)T
A continuación tratamos de resolverpara ∂ L ∂ W hx , ∂ L ∂ W hh \frac{\parcial L}{\parcial W_{hx}},\frac{\parcial L}{\parcial W_{hh}}∂ Wh x∂ L,∂ Wh h∂ L
Comience a resolver en el tiempo T (prod() aquí representa la regla de derivación de cadenas de matrices):
primero tenemos:
{ ht = W hxxt + W hhht − 1 O t = W qhht \left \{ \begin{array }{ll } h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \\ \end{matriz} \right .{
ht=Wh xXt+Wh hht − 1Ot=Wq hht
∂ L ∂ h T = prod ( ∂ L ∂ OT , ∂ OT ∂ h T ) \frac{\parcial L}{\parcial h_T} = prod\left( \frac{\parcial L}{\parcial O_T}, \ frac{\parcial O_T}{\parcial h_T} \right)∂ horaT∂ L=producto _ _ _(∂O _T∂ L,∂ horaT∂O _T)
对于T-1时刻,有
∂ L ∂ h T − 1 = prod ( ∂ L ∂ OT − 1 , ∂ OT − 1 ∂ h T − 1 ) + prod ( ∂ L ∂ h T , ∂ h T ∂ h T − 1 ) \frac{\parcial L}{\parcial h_{T-1}} = prod\left( \frac{\parcial L}{\parcial O_{T-1}}, \frac{\parcial O_{ T-1}}{\parcial h_{T-1}} \right) + prod\left( \frac{\parcial L}{\parcial h_T}, \frac{\parcial h_T}{\parcial h_{T- 1}} \derecho)∂ horaT − 1∂ L=producto _ _ _(∂O _T − 1∂ L,∂ horaT − 1∂O _T − 1)+producto _ _ _(∂ horaT∂ L,∂ horaT − 1∂ horaT)
…
同理,对于t时刻, t < T,有:
∂ L ∂ ht = prod ( ∂ L ∂ O t , ∂ O t ∂ ht ) + prod ( ∂ L ∂ ht + 1 , ∂ ht + 1 ∂ ht ) \frac{\L parcial}{\h_t parcial} = prod\left( \frac{\L parcial}{\O_t parcial}, \frac{\O_t parcial}{\h_t parcial} \right) + prod\left ( \frac{\parcial L}{\parcial h_{t+1}}, \frac{\parcial h_{t+1}}{\parcial h_t} \right)∂ horat∂ L=producto _ _ _(∂O _t∂ L,∂ horat∂O _t)+producto _ _ _(∂ horat + 1∂ L,∂ horat∂ horat + 1)
para encontrar la derivada parcial como arriba para resolver∂ L ∂ W qh \frac{\partial L}{\partial W_{qh}}∂ Wq h∂ LComo se muestra en el método de derivación de cadena de trazas de matriz utilizado en , obtenemos:
∂ L ∂ ht = W qh T ∂ L ∂ O t + W hh T ∂ L ∂ ht + 1 \frac{\partial L}{\partial h_t } = W_{qh}^T \frac{\L parcial}{\O_t parcial} + W_{hh}^T \frac{\L parcial}{\h_{t+1} parcial∂ horat∂ L=Wqh _T∂O _t∂ L+Wh hT∂ horat + 1∂ L
打开该递归公式可得:
∂ L ∂ ht = ∑ yo = t T ( W hh T ) T − yo W qh T ∂ L ∂ OT + t − yo \frac{\parcial L}{\parcial h_t} = \ sum_{i=t}^T \left( W_{hh}^T \right)^{Ti} W_{qh}^T \frac{\parcial L}{\parcial O_{T+ti}}∂ horat∂ L=yo = t∑T( Wh hT)T - yoWqh _T∂O _T + t - yo∂ L
所以
∂ L ∂ W hx = prod ( ∂ L ∂ ht , ∂ ht ∂ W hx ) ∂ L ∂ W hh = prod ( ∂ L ∂ ht , ∂ ht ∂ W hh ) \frac{\parcial L}{\parcial W_ {hx}} = prod\left( \frac{\parcial L}{\parcial h_t}, \frac{\parcial h_t}{\parcial W_{hx}} \right) \\ \frac{\parcial L}{ \parcial W_{hh}} = prod\left( \frac{\parcial L}{\parcial h_t}, \frac{\parcial h_t}{\parcial W_{hh}} \right)∂ Wh x∂ L=producto _ _ _(∂ horat∂ L,∂ Wh x∂ horat)∂ Wh h∂ L=producto _ _ _(∂ horat∂ L,∂ Wh h∂ horat)
y luego (la regla de la cadena de producción aquí es la misma que la anterior, calcule usted mismo):
∂ L ∂ W hx = ∑ t = 1 T ∂ L ∂ htxt T ∂ L ∂ W hh = ∑ t = 1 T ∂ L ∂ htht − 1 T \frac{\parcial L}{\parcial W_{hx}} = \sum_{t=1}^T\frac{\parcial L}{\parcial h_t}x_t^T \\ \frac{ \parcial L }{\parcial W_{hh}} = \sum_{t=1}^T\frac{\parcial L}{\parcial h_t}h_{t-1}^T∂ Wh x∂ L=t = 1∑T∂ horat∂ LXtT∂ Wh h∂ L=t = 1∑T∂ horat∂ Lht - 1T
Más la solución anterior:
∂ L ∂ W qh = ∑ i = 1 T ∂ l ∂ O t ( ht ) T \frac{\partial L}{\partial W_{qh}} = \sum_{i=1} ^{ T} \frac{\parcial l}{\parcial O_t} {\left( h_t \right)}^T∂ Wq h∂ L=yo = 1∑T∂O _t∂ l( ht)Hasta
ahora, se ha deducido la retropropagación de RNN.
2. Derivación de retropropagación de LSTM
1. Descripción del problema
yo t = σ ( W xi X t + W hola H t − 1 + bi ) F t = σ ( W xf X t + W hf H t − 1 + bf ) O t = σ ( W xo X t + W ho H t − 1 + bo ) C t ′ = tanh ( W xc X t + W hc H t − 1 + bc ) C t = F t ⊙ C t − 1 + yo t ⊙ C t ′ H t = O t ⊙ tanh ( C t ) Y t = W qh H t + bq \begin{array}{ll} I_t=\sigma\left( W_{xi}X_t + W_{hi}H_{t-1} + b_i \right) \\ F_t=\sigma\left( W_{xf}X_t + W_{hf}H_{t-1} + b_f \right) \\ O_t=\sigma\left( W_{xo}X_t + W_{ho}H_ {t-1} + b_o \right) \\ C_t^{'}=\mathrm{tanh}\left( W_{xc}X_t + W_{hc}H_{t-1} + b_{c} \right) \\ C_t=F_t \odot C_{t-1} + I_t \odot C_t^{'} \\ H_t=O_t \odot \mathrm{tanh}(C_t) \\ Y_t=W_{qh}H_t + b_q \end {formación}It=pag( Wx yoXt+Whola _Ht − 1+byo)Ft=pag( Wx fXt+Wh fHt − 1+bf)Ot=pag( Wxo _Xt+Whola _Ht − 1+bo)Ct′′=nosotros n h _( Wx cXt+Wh cHt − 1+bdo)Ct=Ft⊙Ct − 1+It⊙Ct′′Ht=Ot⊙t a n h ( Ct)Yt=Wq hHt+bq