Introduction to Deep Learning (59) Recurrent Neural Networks - Backpropagation through Time

Foreword

The core content comes from blog link 1 and blog link 2; please support the original authors.
This article is a personal record to prevent forgetting.

Recurrent Neural Networks - Backpropagation Through Time

Textbook

So far, we have repeatedly alluded to things like exploding gradients, vanishing gradients, and the need to detach gradients in recurrent neural networks. For example, in the RNN implementation from scratch section, we called the detach function on the sequence. None of these concepts was fully explained there, in the interest of building a model quickly and seeing how it works. This section explores the details of backpropagation for sequence models in more depth, together with the associated mathematics.

When we first implemented recurrent neural networks, we encountered the problem of exploding gradients, and we found that gradient clipping is vital to ensure that the model converges. To understand this issue better, this section reviews how gradients are computed for sequence models. Note that there is nothing conceptually new in how this works: after all, we are still just applying the chain rule to compute gradients.

We previously described forward and backward propagation and the associated computational graphs for multilayer perceptrons. Forward propagation in a recurrent neural network is relatively straightforward. Backpropagation through time (BPTT) is in fact a specific application of backpropagation to recurrent neural networks. It requires us to expand the computational graph of the RNN one time step at a time to obtain the dependencies among model variables and parameters. Then, based on the chain rule, backpropagation is applied to compute and store the gradients. Since sequences can be rather long, the dependencies can be rather long as well. For example, for a sequence of 1000 characters, the first token could have a significant effect on the token at the final position. Computing that dependence exactly is not really feasible (it requires too much time and too much memory), and it takes over 1000 matrix products before we arrive at that very elusive gradient. The process is fraught with computational and statistical uncertainty. In the following we clarify what happens and how to address it in practice.

1 Gradient Analysis of Recurrent Neural Network

We start with a simplified model of how a recurrent neural network works, which ignores the details of the properties of the hidden state and how it is updated. The mathematical representation here does not distinguish scalars, vectors, and matrices as clearly as in the past, because these details are not important for the analysis and instead only confuse the notation in this subsection.

In this simplified model, we denote the hidden state at time step $t$ as $h_t$, the input as $x_t$, and the output as $o_t$. Recall our discussion in the RNN section that the input and the hidden state can be concatenated before being multiplied by one weight variable in the hidden layer. Therefore, we use $w_h$ and $w_o$ to denote the weights of the hidden layer and the output layer, respectively. The hidden state and the output at each time step can be written as:
$$\begin{aligned}h_t &= f(x_t, h_{t-1}, w_h),\\ o_t &= g(h_t, w_o),\end{aligned}$$

where $f$ and $g$ are the transformations of the hidden layer and the output layer, respectively. Hence, we have a chain $\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}$ whose elements depend on each other through recurrent computation. Forward propagation is fairly straightforward: we traverse the triples $(x_t, h_t, o_t)$ one time step at a time, and then evaluate the discrepancy between the outputs $o_t$ and the corresponding labels $y_t$ over all $T$ time steps with an objective function:
$$L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_h, w_o) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t).$$
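To make the simplified model concrete, here is a minimal sketch of forward propagation and the objective above. The concrete choices of $f$, $g$, and $l$ (tanh, a linear map, and squared error) are illustrative assumptions, not fixed by the text:

```python
import numpy as np

def f(x_t, h_prev, w_h):
    # hidden-layer transformation; tanh is an illustrative choice
    return np.tanh(w_h * (x_t + h_prev))

def g(h_t, w_o):
    # output-layer transformation; a linear map is an illustrative choice
    return w_o * h_t

def objective(xs, ys, w_h, w_o):
    """Traverse the triples (x_t, h_t, o_t) and average the per-step losses l(y_t, o_t)."""
    h, total = 0.0, 0.0
    for x_t, y_t in zip(xs, ys):
        h = f(x_t, h, w_h)
        o_t = g(h, w_o)
        total += 0.5 * (y_t - o_t) ** 2   # l(y_t, o_t): squared error for illustration
    return total / len(xs)                # L = (1/T) * sum_t l(y_t, o_t)
```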
For backpropagation, matters are a bit trickier, especially when we compute the gradient of the objective function $L$ with respect to the parameter $w_h$. To be specific, by the chain rule:
$$\frac{\partial L}{\partial w_h} = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_h} = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_o)}{\partial h_t} \frac{\partial h_t}{\partial w_h}.$$
The first and second factors of the product above are easy to compute, while the third factor $\partial h_t/\partial w_h$ is where things get tricky, since we need to recurrently compute the effect of the parameter $w_h$ on $h_t$. According to the recurrence $h_t = f(x_t, h_{t-1}, w_h)$, $h_t$ depends both on $h_{t-1}$ and on $w_h$, where the computation of $h_{t-1}$ also depends on $w_h$. Thus, applying the chain rule yields:
$$\frac{\partial h_t}{\partial w_h}= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.$$
To derive the gradient above, assume that we have three sequences $\{a_{t}\},\{b_{t}\},\{c_{t}\}$ satisfying $a_{0}=0$ and $a_{t}=b_{t}+c_{t}a_{t-1}$ for $t=1,2,\ldots$. Then for $t\geq 1$, it is easy to show that

$$a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.$$

Now substitute $a_t$, $b_t$, and $c_t$ according to

$$\begin{aligned}a_t &= \frac{\partial h_t}{\partial w_h},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h},\\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}}.\end{aligned}$$

With these substitutions, the gradient recurrence above satisfies $a_{t}=b_{t}+c_{t}a_{t-1}$. Hence, by the closed-form expression, we can eliminate the recursive computation and obtain

$$\frac{\partial h_t}{\partial w_h}=\frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_h)}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_h)}{\partial w_h}.$$
Although we can use the chain rule to compute $\partial h_t/\partial w_h$ recursively, this chain can get very long whenever $t$ is large. We need to find a way to deal with this problem.
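As a quick sanity check on the algebra, the following sketch (with arbitrary made-up values for $b_t$ and $c_t$) confirms numerically that the recursion $a_t = b_t + c_t a_{t-1}$ and the unrolled sum above agree:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 8
b = rng.normal(size=T + 1)   # b_1 ... b_T live at indices 1..T (index 0 unused)
c = rng.normal(size=T + 1)   # c_1 ... c_T

# Recursion: a_t = b_t + c_t * a_{t-1}, with a_0 = 0
a_rec = 0.0
for t in range(1, T + 1):
    a_rec = b[t] + c[t] * a_rec

# Closed form: a_T = b_T + sum_{i=1}^{T-1} (prod_{j=i+1}^{T} c_j) * b_i
a_closed = b[T] + sum(np.prod(c[i + 1:T + 1]) * b[i] for i in range(1, T))

print(np.isclose(a_rec, a_closed))   # True: the two expressions agree
```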

1.1 Complete Computation

Obviously, we can just compute the full sum in the expanded expression above. However, this is very slow and gradients can blow up, since subtle changes in the initial conditions can potentially have a large effect on the outcome. That is, we could see phenomena similar to the butterfly effect, where minimal changes in the initial conditions lead to disproportionate changes in the outcome. This is quite undesirable for the model we want to estimate: after all, we are looking for robust estimators that generalize well. Hence this strategy is almost never used in practice.

1.2 Truncating Time Steps

Alternatively, we can truncate the sum in the expanded expression above after $\tau$ steps. This is what we have been discussing so far, e.g., when we detached the gradients in the RNN implementation from scratch section. This leads to an approximation of the true gradient, simply by terminating the sum at $\partial h_{t-\tau}/\partial w_h$. In practice this works quite well. It is what is commonly referred to as truncated backpropagation through time. One of the consequences is that the model focuses primarily on short-term influence rather than long-term consequences. This is actually desirable, since it biases the estimate towards simpler and more stable models.
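In code, this truncation corresponds to the detach call from the from-scratch section: the hidden state carried into the next subsequence is cut out of the computational graph, so gradients flow back over at most $\tau$ time steps. Below is a minimal PyTorch-style sketch; the rnn_step function, the parameter list, and the tensor shapes are placeholders for illustration, not the book's actual implementation:

```python
import torch

def train_chunk(rnn_step, params, X_chunk, Y_chunk, h, loss_fn, lr):
    """One parameter update with truncated BPTT over a subsequence of length tau."""
    h = h.detach()                       # cut the graph: no gradient flows past this chunk
    outputs = []
    for x_t in X_chunk:                  # X_chunk has shape (tau, batch_size, num_inputs)
        h, o_t = rnn_step(x_t, h, params)
        outputs.append(o_t)
    loss = loss_fn(torch.stack(outputs), Y_chunk)
    loss.backward()                      # gradients span at most tau time steps
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad             # plain SGD step for illustration
            p.grad.zero_()
    return h, float(loss)
```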

1.3 Random truncation

Finally, we can replace $\partial h_t/\partial w_h$ with a random variable that is correct in expectation but truncates the sequence. This is achieved by using a sequence of random variables $\xi_t$ with predefined $0 \leq \pi_t \leq 1$, where $P(\xi_t = 0) = 1-\pi_t$ and $P(\xi_t = \pi_t^{-1}) = \pi_t$, and hence $E[\xi_t] = 1$. We use this to replace the gradient $\partial h_t/\partial w_h$ in the recurrence above, which yields:

$$z_t= \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial w_h} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_h)}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_h}.$$

It follows from the definition of $\xi_t$ that $E[z_t] = \partial h_t/\partial w_h$. Whenever $\xi_t = 0$, the recursive computation terminates at that time step $t$. This leads to a weighted sum of sequences of varying lengths, where long sequences are rare but appropriately overweighted.
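A sketch of how the $\xi_t$ trick could be realized in an autograd framework (the constant keep-probability $\pi_t = \pi$ is an assumption for illustration): before each step, the carried hidden state is either detached with probability $1-\pi$, or its gradient path is rescaled by $1/\pi$ so that the estimator stays unbiased, while the forward value is left unchanged.

```python
import torch

def randomized_truncation(h, pi):
    """Apply xi_t to the previous hidden state h before the next RNN step."""
    if torch.rand(()) < pi:
        # xi_t = 1/pi: forward value stays h, but the gradient path is scaled by 1/pi
        return h / pi - (h / pi - h).detach()
    # xi_t = 0: stop the gradient here; forward value stays h
    return h.detach()
```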

1.4 Comparing Strategies

(Figure) Comparing strategies for computing gradients in RNNs; the three lines from top to bottom are: random truncation, regular truncation, and full computation.

The figure above illustrates three strategies for analyzing the first few characters in The Time Machine when using backpropagation through time based on recurrent neural networks:

  • The first line shows random truncation, which partitions the text into segments of varying lengths;

  • The second line applies regular truncation by breaking the text into subsequences of equal length. This is what we have been doing in our recurrent neural network experiments;

  • The third line employs full backpropagation through time, resulting in expressions that are computationally infeasible.

Unfortunately, while appealing in theory, random truncation probably does not work much better than regular truncation in practice, most likely due to a number of factors. First, the effect of an observation after a number of backpropagation steps into the past is quite sufficient to capture dependencies in practice. Second, the increased variance counteracts the fact that the gradient is more accurate with more steps. Third, we actually want models that have only a short range of interactions. Hence, regularly truncated backpropagation through time has a slight regularizing effect that can be desirable.

2 Details of backpropagation through time

Having discussed the general principles, let us look at the details of backpropagation through time. In contrast to the analysis above, below we show how to compute the gradients of the objective function with respect to all the decomposed model parameters. To keep things simple, we consider a recurrent neural network without bias parameters, whose activation function in the hidden layer uses the identity mapping $\phi(x)=x$. For time step $t$, let the single-example input and the corresponding label be $\mathbf{x}_t \in \mathbb{R}^d$ and $y_t$, respectively. The hidden state $\mathbf{h}_t \in \mathbb{R}^h$ and the output $\mathbf{o}_t \in \mathbb{R}^q$ are computed as:
$$\begin{aligned}\mathbf{h}_t &= \mathbf{W}_{hx}\mathbf{x}_t + \mathbf{W}_{hh}\mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh}\mathbf{h}_{t},\end{aligned}$$
where the weight parameters are $\mathbf{W}_{hx} \in \mathbb{R}^{h \times d}$, $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$, and $\mathbf{W}_{qh} \in \mathbb{R}^{q \times h}$. Denote by $l(\mathbf{o}_t, y_t)$ the loss at time step $t$. Our objective function, the total loss over $T$ time steps from the beginning of the sequence, is then:
$$L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).$$
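A minimal sketch of this bias-free, identity-activation RNN and its loss in PyTorch (the sizes $d$, $h$, $q$, $T$ and the use of a squared-error loss for $l$ are illustrative assumptions):

```python
import torch

d, num_hiddens, q, T = 4, 8, 3, 10                       # illustrative sizes
W_hx = (0.1 * torch.randn(num_hiddens, d)).requires_grad_()
W_hh = (0.1 * torch.randn(num_hiddens, num_hiddens)).requires_grad_()
W_qh = (0.1 * torch.randn(q, num_hiddens)).requires_grad_()

xs = [torch.randn(d) for _ in range(T)]                  # single-example inputs x_1..x_T
ys = [torch.randn(q) for _ in range(T)]                  # labels y_1..y_T (regression targets for illustration)

h = torch.zeros(num_hiddens)
hs, os = [h], []                                         # store intermediate values for reuse
L = torch.zeros(())
for x_t, y_t in zip(xs, ys):
    h = W_hx @ x_t + W_hh @ h                            # identity activation, no bias
    o_t = W_qh @ h
    hs.append(h)
    os.append(o_t)
    L = L + 0.5 * ((o_t - y_t) ** 2).sum()               # l(o_t, y_t)
L = L / T
L.backward()                                             # autograd performs BPTT for us
```

The stored hs and os, as well as the resulting W_hx.grad, W_hh.grad, and W_qh.grad, are reused later when we spell out the gradients by hand.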

In order to visualize the dependencies among model variables and parameters during computation of the recurrent neural network, we can draw a computational graph for the model, as shown in the figure below. For example, the computation of the hidden state of time step 3, $\mathbf{h}_3$, depends on the model parameters $\mathbf{W}_{hx}$ and $\mathbf{W}_{hh}$, the hidden state of the previous time step $\mathbf{h}_2$, and the input of the current time step $\mathbf{x}_3$.

(Figure) Computational graph showing the dependencies of a recurrent neural network model with three time steps. Uncolored boxes represent variables, colored boxes represent parameters, and circles represent operators.

As just mentioned, the model parameters in the figure above are $\mathbf{W}_{hx}$, $\mathbf{W}_{hh}$, and $\mathbf{W}_{qh}$. Generally, training this model requires gradient computation with respect to these parameters: $\partial L/\partial \mathbf{W}_{hx}$, $\partial L/\partial \mathbf{W}_{hh}$, and $\partial L/\partial \mathbf{W}_{qh}$. According to the dependencies in the figure, we can traverse the computational graph in the opposite direction of the arrows and compute and store the gradients in turn. To flexibly express the multiplication of matrices, vectors, and scalars of different shapes in the chain rule, we continue to use the $\text{prod}$ operator.

First of all, differentiating the objective function with respect to the model output at any time step $t$ is fairly straightforward:

$$\frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l(\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.$$

Now we can compute the gradient of the objective function with respect to the parameter $\mathbf{W}_{qh}$ in the output layer: $\partial L/\partial \mathbf{W}_{qh} \in \mathbb{R}^{q \times h}$. Based on the figure above, the objective function $L$ depends on $\mathbf{W}_{qh}$ via $\mathbf{o}_1, \ldots, \mathbf{o}_T$. Using the chain rule yields

$$\frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top,$$

where $\partial L/\partial \mathbf{o}_t$ is given by the expression above.
Next, as shown in the figure, at the final time step $T$ the objective function $L$ depends on the hidden state $\mathbf{h}_T$ only via $\mathbf{o}_T$. Therefore, we can easily find the gradient $\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h$ using the chain rule:

$$\frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}.$$
It gets trickier for any time step $t < T$, where the objective function $L$ depends on $\mathbf{h}_t$ via $\mathbf{h}_{t+1}$ and $\mathbf{o}_t$. According to the chain rule, the gradient of the hidden state $\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h$ at any time step $t < T$ can be recursively computed as:

$$\frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.$$

For analysis, expanding the recursive computation for any time step $1 \leq t \leq T$ gives

$$\frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.$$

We can see from the expression above that this simple linear example already exhibits some key problems of long sequence models: it involves potentially very large powers of $\mathbf{W}_{hh}^\top$. In these powers, eigenvalues smaller than 1 vanish and eigenvalues larger than 1 diverge. This is numerically unstable, which manifests itself in the form of vanishing or exploding gradients. One way to address this is to truncate the time steps at a computationally convenient size, as discussed in the first part. In practice, this truncation is effected by detaching the gradient after a given number of time steps. Later on, we will see how more sophisticated sequence models, such as long short-term memory, can further alleviate this problem.
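A quick numerical illustration of this point (the matrix is random and purely illustrative): rescale $\mathbf{W}_{hh}$ so that its spectral radius is below or above 1 and look at the size of $(\mathbf{W}_{hh}^\top)^k$ for a moderately large $k$:

```python
import torch

k = 50
W = torch.randn(8, 8)
rho = torch.linalg.eigvals(W).abs().max()          # spectral radius of the random matrix

for target in (0.5, 1.5):                          # target spectral radius below and above 1
    W_hh = (target / rho) * W
    power = torch.linalg.matrix_power(W_hh.T, k)
    print(f"spectral radius {target}: norm of (W_hh^T)^{k} = {power.norm().item():.3e}")
# With radius 0.5 the norm is vanishingly small (vanishing gradients);
# with radius 1.5 it is astronomically large (exploding gradients).
```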

Finally, the figure above shows that the objective function $L$ depends on the model parameters $\mathbf{W}_{hx}$ and $\mathbf{W}_{hh}$ in the hidden layer via the hidden states $\mathbf{h}_1, \ldots, \mathbf{h}_T$. To compute the gradients with respect to these parameters, $\partial L / \partial \mathbf{W}_{hx} \in \mathbb{R}^{h \times d}$ and $\partial L / \partial \mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$, we apply the chain rule and obtain:
$$\begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned}$$

where $\partial L/\partial \mathbf{h}_t$, obtained by the recursive computation in the two equations above, is the key quantity that affects numerical stability.

As we explained earlier, since backpropagation through time is the application of backpropagation to recurrent neural networks, training an RNN alternates forward propagation with backpropagation through time. The gradients above are computed and stored in turn by backpropagation through time. Specifically, stored intermediate values are reused to avoid duplicate computations; for example, $\partial L/\partial \mathbf{h}_t$ is stored so that it can be used when computing both $\partial L / \partial \mathbf{W}_{hx}$ and $\partial L / \partial \mathbf{W}_{hh}$.
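To make the reuse of $\partial L/\partial \mathbf{h}_t$ explicit, here is a sketch of a manual BPTT pass for the illustrative linear RNN above. It assumes the stored forward values hs and os and the squared-error loss from the earlier sketch:

```python
import torch

def manual_bptt(xs, ys, hs, os, W_hx, W_hh, W_qh):
    """xs, ys, os: lists of length T; hs: length T + 1 with hs[0] the initial state."""
    T = len(xs)
    dW_hx, dW_hh, dW_qh = (torch.zeros_like(W) for W in (W_hx, W_hh, W_qh))
    dh_next = torch.zeros(W_hh.shape[0])             # dL/dh_{t+1}; zero when t = T
    for t in range(T, 0, -1):                        # walk backwards through time
        do_t = (os[t - 1] - ys[t - 1]) / T           # dL/do_t for the averaged squared loss
        dW_qh += torch.outer(do_t, hs[t])            # accumulate dL/do_t h_t^T
        dh_t = W_hh.T @ dh_next + W_qh.T @ do_t      # recursive dL/dh_t, stored for reuse
        dW_hx += torch.outer(dh_t, xs[t - 1])        # reuse dL/dh_t: dL/dh_t x_t^T
        dW_hh += torch.outer(dh_t, hs[t - 1])        # reuse dL/dh_t: dL/dh_t h_{t-1}^T
        dh_next = dh_t
    return dW_hx, dW_hh, dW_qh
```

Comparing the returned tensors against W_hx.grad, W_hh.grad, and W_qh.grad from the autograd run above should agree up to floating-point error.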

3 Summary

  • "Backpropagation through time" only applies to backpropagation on sequence models with hidden states.

  • Truncation is needed for computational convenience and numerical stability. Truncation methods include regular truncation and random truncation.

  • High powers of matrices can cause eigenvalues to diverge or vanish, which manifests itself in the form of exploding or vanishing gradients.

  • For computational efficiency, backpropagation through time caches intermediate values during the computation.


Origin blog.csdn.net/qq_52358603/article/details/128275572