retropropagación a través del tiempo

retropropagación a través del tiempo

Basado en el libro "Hands-On Deep Learning", este artículo ofrece una derivación relativamente detallada de los capítulos correspondientes.

1. Derivación de retropropagación de RNN

1. Descripción del problema

Esta es la expresión relacional de la red RNN en el tiempo t:
{ ht = W hxxt + W hhht − 1 O t = W qhht \left \{ \begin{array}{ll} h_t = W_{hx}x_t + W_{ hh} h_{t-1} \\ O_t = W_{qh}h_t \\ \end{matriz} \right .{ ht=Wh xXt+Wh hht 1Ot=Wq hht

Supongamos que la función de pérdida
L = 1 T ∑ t = 1 nl ( O t , yt ) L = \frac{1}{T}\sum_{t=1}^{n}l(O_t, y_t)L=T1t = 1nyo ( Ot,yt)

欲求
∂ L ∂ W qh , ∂ L ∂ W hx , ∂ L ∂ W hh \frac{\parcial L}{\parcial W_{qh}}, \frac{\parcial L}{\parcial W_{hx}}, \frac{\parcial L}{\parcial W_{hh}}Wq h L,Wh x L,Wh h L
Algunas preparaciones: Es necesario dominar la derivación de cadenas de matriz y las reglas y principios básicos de derivación .

2. Resolución de problemas

Primero , resuelve para ∂ L ∂ W qh \frac{\partial L}{\partial W_{qh}}Wq h L
Para cualquier momento ttt,显然有:
∂ L ∂ O t = 1 T ⋅ ∂ l ( O t , yt ) ∂ O tdl = tr ( ( ∂ l ∂ O t ) T ⋅ re O t ) O t = W qhht \frac{\ L parcial}{\O_t parcial} = \frac{1}{T} \cdot \frac{\l parcial(O_t, y_t)}{\O_t parcial} \\ \mathrm{d}l = tr\left( { \left( \frac{\parcial l}{\parcial O_t} \right)}^T \cdot \mathrm{d}O_t \right) \\ O_t = W_{qh}h{t}∂O _t L=T1∂O _tl ( Ot,yt)dl _=t r( (∂O _t l)Td Ot)Ot=Wq hh t
Por lo tanto,O t O_tOt带入微分式中,有:
re L = tr ( ∑ i = 1 T ( ∂ l ∂ O t ) T re W qh ⋅ ht ) \mathrm{d}L = tr\left( \sum_{i=1} ^{T}{\left( \frac{\parcial l}{\parcial O_t} \right)}^T \mathrm{d}W_{qh} \cdot h_t \right)d L=t r(yo = 1T(∂O _t l)Td W_q hht)
va aht h_thtPóngalo en el lado derecho de la traza:
d L = tr ( ∑ i = 1 T ht ( ∂ l ∂ O t ) T d W qh ) \mathrm{d}L = tr\left( \sum_{i=1 } ^{T}h_t{\left( \frac{\parcial l}{\parcial O_t} \right)}^T \mathrm{d}W_{qh} \right)d L=t r(yo = 1Tht(∂O _t l)Td W_q h)
因此:
∂ L ∂ W qh = ( ∑ yo = 1 T ht ( ∂ l ∂ O t ) T ) T = ∑ i = 1 T ∂ l ∂ O t ( ht ) T \frac{\parcial L}{\ parcial W_{qh}} = \left( \sum_{i=1}^{T}h_t{\left( \frac{\partial l}{\partial O_t} \right)}^T \right)^T = \sum_{i=1}^{T} \frac{\parcial l}{\parcial O_t} {\left( h_t \right)}^TWq h L=(yo = 1Tht(∂O _t l)t )T=yo = 1T∂O _t l( ht)T
A continuación tratamos de resolverpara ∂ L ∂ W hx , ∂ L ∂ W hh \frac{\parcial L}{\parcial W_{hx}},\frac{\parcial L}{\parcial W_{hh}}Wh x L,Wh h L
Comience a resolver en el tiempo T (prod() aquí representa la regla de derivación de cadenas de matrices):
primero tenemos:
{ ht = W hxxt + W hhht − 1 O t = W qhht \left \{ \begin{array }{ll } h_t = W_{hx}x_t + W_{hh}h_{t-1} \\ O_t = W_{qh}h_t \\ \end{matriz} \right .{ ht=Wh xXt+Wh hht 1Ot=Wq hht

∂ L ∂ h T = prod ( ∂ L ∂ OT , ∂ OT ∂ h T ) \frac{\parcial L}{\parcial h_T} = prod\left( \frac{\parcial L}{\parcial O_T}, \ frac{\parcial O_T}{\parcial h_T} \right)horaT L=producto _ _ _(∂O _T L,horaT∂O _T)
对于T-1时刻,有
∂ L ∂ h T − 1 = prod ( ∂ L ∂ OT − 1 , ∂ OT − 1 ∂ h T − 1 ) + prod ( ∂ L ∂ h T , ∂ h T ∂ h T − 1 ) \frac{\parcial L}{\parcial h_{T-1}} = prod\left( \frac{\parcial L}{\parcial O_{T-1}}, \frac{\parcial O_{ T-1}}{\parcial h_{T-1}} \right) + prod\left( \frac{\parcial L}{\parcial h_T}, \frac{\parcial h_T}{\parcial h_{T- 1}} \derecho)horaT 1 L=producto _ _ _(∂O _T 1 L,horaT 1∂O _T 1)+producto _ _ _(horaT L,horaT 1horaT)

同理,对于t时刻, t < T,有:
∂ L ∂ ht = prod ( ∂ L ∂ O t , ∂ O t ∂ ht ) + prod ( ∂ L ∂ ht + 1 , ∂ ht + 1 ∂ ht ) \frac{\L parcial}{\h_t parcial} = prod\left( \frac{\L parcial}{\O_t parcial}, \frac{\O_t parcial}{\h_t parcial} \right) + prod\left ( \frac{\parcial L}{\parcial h_{t+1}}, \frac{\parcial h_{t+1}}{\parcial h_t} \right)horat L=producto _ _ _(∂O _t L,horat∂O _t)+producto _ _ _(horat + 1 L,horathorat + 1)
para encontrar la derivada parcial como arriba para resolver∂ L ∂ W qh \frac{\partial L}{\partial W_{qh}}Wq h LComo se muestra en el método de derivación de cadena de trazas de matriz utilizado en , obtenemos:
∂ L ∂ ht = W qh T ∂ L ∂ O t + W hh T ∂ L ∂ ht + 1 \frac{\partial L}{\partial h_t } = W_{qh}^T \frac{\L parcial}{\O_t parcial} + W_{hh}^T \frac{\L parcial}{\h_{t+1} parcialhorat L=Wqh _T∂O _t L+Wh hThorat + 1 L
打开该递归公式可得:
∂ L ∂ ht = ∑ yo = t T ( W hh T ) T − yo W qh T ∂ L ∂ OT + t − yo \frac{\parcial L}{\parcial h_t} = \ sum_{i=t}^T \left( W_{hh}^T \right)^{Ti} W_{qh}^T \frac{\parcial L}{\parcial O_{T+ti}}horat L=yo = tT( Wh hT)T - yoWqh _T∂O _T + t - yo L
所以
∂ L ∂ W hx = prod ( ∂ L ∂ ht , ∂ ht ∂ W hx ) ∂ L ∂ W hh = prod ( ∂ L ∂ ht , ∂ ht ∂ W hh ) \frac{\parcial L}{\parcial W_ {hx}} = prod\left( \frac{\parcial L}{\parcial h_t}, \frac{\parcial h_t}{\parcial W_{hx}} \right) \\ \frac{\parcial L}{ \parcial W_{hh}} = prod\left( \frac{\parcial L}{\parcial h_t}, \frac{\parcial h_t}{\parcial W_{hh}} \right)Wh x L=producto _ _ _(horat L,Wh xhorat)Wh h L=producto _ _ _(horat L,Wh hhorat)
y luego (la regla de la cadena de producción aquí es la misma que la anterior, calcule usted mismo):
∂ L ∂ W hx = ∑ t = 1 T ∂ L ∂ htxt T ∂ L ∂ W hh = ∑ t = 1 T ∂ L ∂ htht − 1 T \frac{\parcial L}{\parcial W_{hx}} = \sum_{t=1}^T\frac{\parcial L}{\parcial h_t}x_t^T \\ \frac{ \parcial L }{\parcial W_{hh}} = \sum_{t=1}^T\frac{\parcial L}{\parcial h_t}h_{t-1}^TWh x L=t = 1Thorat LXtTWh h L=t = 1Thorat Lht - 1T
Más la solución anterior:
∂ L ∂ W qh = ∑ i = 1 T ∂ l ∂ O t ( ht ) T \frac{\partial L}{\partial W_{qh}} = \sum_{i=1} ^{ T} \frac{\parcial l}{\parcial O_t} {\left( h_t \right)}^TWq h L=yo = 1T∂O _t l( ht)Hasta
ahora, se ha deducido la retropropagación de RNN.

2. Derivación de retropropagación de LSTM

1. Descripción del problema

yo t = σ ( W xi X t + W hola H t − 1 + bi ) F t = σ ( W xf X t + W hf H t − 1 + bf ) O t = σ ( W xo X t + W ho H t − 1 + bo ) C t ′ = tanh ( W xc X t + W hc H t − 1 + bc ) C t = F t ⊙ C t − 1 + yo t ⊙ C t ′ H t = O t ⊙ tanh ( C t ) Y t = W qh H t + bq \begin{array}{ll} I_t=\sigma\left( W_{xi}X_t + W_{hi}H_{t-1} + b_i \right) \\ F_t=\sigma\left( W_{xf}X_t + W_{hf}H_{t-1} + b_f \right) \\ O_t=\sigma\left( W_{xo}X_t + W_{ho}H_ {t-1} + b_o \right) \\ C_t^{'}=\mathrm{tanh}\left( W_{xc}X_t + W_{hc}H_{t-1} + b_{c} \right) \\ C_t=F_t \odot C_{t-1} + I_t \odot C_t^{'} \\ H_t=O_t \odot \mathrm{tanh}(C_t) \\ Y_t=W_{qh}H_t + b_q \end {formación}It=pag( Wx yoXt+Whola _Ht 1+byo)Ft=pag( Wx fXt+Wh fHt 1+bf)Ot=pag( Wxo _Xt+Whola _Ht 1+bo)Ct′′=nosotros n h _( Wx cXt+Wh cHt 1+bdo)Ct=FtCt 1+ItCt′′Ht=Ott a n h ( Ct)Yt=Wq hHt+bq

2. Resolución de problemas

Supongo que te gusta

Origin blog.csdn.net/qq_41624557/article/details/117304654
Recomendado
Clasificación