Backpropagation Through Time and RNN Variants (BPTT, RNNs)

Backpropagation Through Time, BPTT

[Figure: computational dependencies for an RNN model with three timesteps]

Vanilla RNN model (the activation function is omitted to keep the derivation simple):
$$\pmb h_t=\pmb W_{hx}\pmb x_t+\pmb W_{hh}\pmb h_{t-1}, \quad\pmb o_t=\pmb W_{oh}\pmb h_t$$

The total prediction error over $T$ steps is:
$$L(\pmb x,\pmb y,\pmb W)=\sum_{t=1}^T l(\pmb o_t,\pmb y_t)$$
Taking the derivative with respect to $\pmb W_{oh}$ is fairly straightforward:
$$\frac{\partial L}{\partial\pmb W_{oh}}=\sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}}\, \pmb h_{t}^\top$$
The dependency on $\pmb W_{hx}$ and $\pmb W_{hh}$ is a bit trickier, since it involves a chain of derivatives:
$$\frac{\partial L}{\partial\pmb W_{hx}}= \sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}}\frac{\partial\pmb o_{t}}{\partial\pmb h_t}\frac{\partial\pmb h_t}{\partial\pmb W_{hx}}$$
$$\frac{\partial L}{\partial\pmb W_{hh}}=\sum_{1\leq t\leq T}\frac{\partial l}{\partial \pmb o_{t}}\frac{\partial\pmb o_{t}}{\partial\pmb h_t}\frac{\partial\pmb h_t}{\partial\pmb W_{hh}}$$
After all, hidden states depend on each other and on past inputs:
$$\frac{\partial\pmb h_{t+1}}{\partial\pmb h_t}=\pmb W_{hh}^\top \implies \frac{\partial\pmb h_{T}}{\partial\pmb h_t}=(\pmb W_{hh}^\top)^{T-t}$$
Chaining terms together yields:
$$\frac{\partial\pmb h_t}{\partial\pmb W_{hx}}=\pmb x_t+\pmb W_{hh}^\top\frac{\partial\pmb h_{t-1}}{\partial\pmb W_{hx}}=\sum_{j=1}^t(\pmb W_{hh}^\top)^{t-j}\pmb x_j$$
$$\frac{\partial\pmb h_t}{\partial\pmb W_{hh}}=\pmb h_{t-1}+\pmb W_{hh}^\top\frac{\partial\pmb h_{t-1}}{\partial\pmb W_{hh}}=\sum_{j=1}^t(\pmb W_{hh}^\top)^{t-j}\pmb h_{j-1}$$

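The gradient therefore has a long-term dependency on the matrix $\pmb W_{hh}$: it involves the powers $(\pmb W_{hh}^\top)^{t-j}$, whose exponent grows with the distance between timesteps $j$ and $t$.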
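To make the recursions above concrete, here is a minimal sketch in plain Python. It assumes the scalar case (every weight, input and hidden state is a single number), a squared-error loss $l(o_t,y_t)=\tfrac12(o_t-y_t)^2$ and $h_0=0$; none of these choices come from the text. It accumulates the gradients with the recursions for $\partial h_t/\partial W_{hx}$ and $\partial h_t/\partial W_{hh}$ and compares them with central finite differences.

```python
def forward(w_hx, w_hh, w_oh, xs, h0=0.0):
    """Linear scalar RNN: h_t = w_hx*x_t + w_hh*h_{t-1}, o_t = w_oh*h_t."""
    hs, outs = [h0], []
    for x in xs:
        h = w_hx * x + w_hh * hs[-1]
        hs.append(h)
        outs.append(w_oh * h)
    return hs, outs


def loss(w_hx, w_hh, w_oh, xs, ys):
    """Total loss L = sum_t 0.5 * (o_t - y_t)^2 (squared error is an assumption)."""
    _, outs = forward(w_hx, w_hh, w_oh, xs)
    return sum(0.5 * (o - y) ** 2 for o, y in zip(outs, ys))


def bptt_grads(w_hx, w_hh, w_oh, xs, ys):
    """(dL/dw_hx, dL/dw_hh, dL/dw_oh) using the recursions from the text."""
    hs, outs = forward(w_hx, w_hh, w_oh, xs)
    g_hx = g_hh = g_oh = 0.0
    dh_dwhx = dh_dwhh = 0.0  # dh_0/dw = 0 because h_0 is a constant
    for t in range(1, len(xs) + 1):
        dh_dwhx = xs[t - 1] + w_hh * dh_dwhx  # x_t + w_hh * dh_{t-1}/dw_hx
        dh_dwhh = hs[t - 1] + w_hh * dh_dwhh  # h_{t-1} + w_hh * dh_{t-1}/dw_hh
        dl_do = outs[t - 1] - ys[t - 1]       # dl/do_t for squared error
        g_oh += dl_do * hs[t]
        g_hx += dl_do * w_oh * dh_dwhx        # dl/do_t * do_t/dh_t * dh_t/dw_hx
        g_hh += dl_do * w_oh * dh_dwhh
    return g_hx, g_hh, g_oh


def numeric_grads(w, xs, ys, eps=1e-6):
    """Central finite differences of the loss for each of the three weights."""
    grads = []
    for k in range(3):
        wp, wm = list(w), list(w)
        wp[k] += eps
        wm[k] -= eps
        grads.append((loss(*wp, xs, ys) - loss(*wm, xs, ys)) / (2 * eps))
    return grads


w = (0.5, 0.9, 1.2)  # (w_hx, w_hh, w_oh), arbitrary example values
xs, ys = [1.0, -0.5, 0.3, 0.8], [0.2, 0.1, -0.4, 0.5]
print(bptt_grads(*w, xs, ys))
print(numeric_grads(w, xs, ys))  # should agree with the line above
```

Because the running values `dh_dwhx` and `dh_dwhh` are multiplied by `w_hh` at every step, the contribution of input $x_j$ to the gradient at step $t$ carries the factor $w_{hh}^{t-j}$, which is the scalar version of the long-term dependency described above.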

In some scenarios the RNN model uses only the output of the last state, in which case $L=l(\pmb o_T,\pmb y_T)$.


Vanishing and Exploding Gradients in Vanilla RNNs

RNNs suffer from the problem of vanishing and exploding gradients, which hampers learning of long data sequences. For example, consider a simplified RNN that takes no input $x$ and only computes the recurrence on the hidden state (equivalently, the input $x$ could always be zero).

The gradient signal going backwards in time through all the hidden states is always multiplied by the same matrix (the recurrence matrix $\pmb W_{hh}$), interspersed with backpropagation through the nonlinearity.

When you take a number $a$ and keep multiplying it by another number $b$ (i.e. $a\cdot b\cdot b\cdot b\cdots$), the sequence either goes to zero if $|b| < 1$ or explodes to infinity if $|b|>1$. The same thing happens in the backward pass of an RNN, except that $b$ is a matrix rather than a single number.
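The matrix version of this argument can be seen in a few lines of plain Python. The sketch below assumes a two-dimensional hidden state, ignores the nonlinearity, and uses diagonal recurrence matrices chosen purely for illustration, so their diagonal entries are their eigenvalues. It multiplies a gradient vector by $\pmb W_{hh}^\top$ once per step, as in the backward pass of the linear RNN above.

```python
import math


def backprop_norms(w, steps, g0=(1.0, 1.0)):
    """Norm of a 2-d gradient after repeated multiplication by W^T (nonlinearity ignored)."""
    g = list(g0)
    norms = []
    for _ in range(steps):
        # g <- W^T g, written out for a 2x2 matrix given as a list of rows
        g = [w[0][0] * g[0] + w[1][0] * g[1],
             w[0][1] * g[0] + w[1][1] * g[1]]
        norms.append(math.hypot(g[0], g[1]))
    return norms


shrinking = backprop_norms([[0.5, 0.0], [0.0, 0.8]], 50)  # all eigenvalues below 1
growing = backprop_norms([[1.2, 0.0], [0.0, 0.5]], 50)    # one eigenvalue above 1
print(shrinking[-1], growing[-1])  # roughly 0.8**50 and 1.2**50
```

For a general, non-diagonal matrix the long-run behavior of a typical gradient vector is governed in the same way by the eigenvalue of largest magnitude.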

If the gradient vanishes, the signal from later timesteps never reaches the earlier hidden states, so the model cannot learn how earlier states affect later ones: no long-term dependencies are learned. If the gradient explodes, the parameter updates become huge and unstable, which also makes learning difficult.

There are a few ways to combat the vanishing gradient problem. Proper initialization of the $\pmb W$ matrices can reduce the effect of vanishing gradients. A preferred solution is to use ReLU instead of tanh or sigmoid activation functions: the ReLU derivative is a constant of either 0 or 1, so it is less likely to suffer from vanishing gradients. An even more popular solution is to use Long Short-Term Memory (LSTM) or Gated Recurrent Unit (GRU) architectures.


Long Short-Term Memory Networks, LSTMs

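LSTMs address the long-term dependency problem (the vanishing gradient problem) that vanilla RNNs cannot handle, using three gates to control the long-term state (memory).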

On timestep $t$ (products of gate vectors below are elementwise):

  • forget gate $f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f)$
    • controls what parts of the previous cell state $c_{t-1}$ are kept (rather than forgotten) in the cell state $c_t$.
  • input gate $i_t=\sigma(W_i\cdot[h_{t-1},x_t]+b_i)$
    • controls what parts of the new cell content are written to the cell state $c_t$.
  • output gate $o_t=\sigma(W_o\cdot[h_{t-1},x_t]+b_o)$
    • controls what parts of the cell state are output to the hidden state.
  • new cell content $c'_t=\tanh(W_c\cdot [h_{t-1},x_t]+b_c)$
    • new content to be written to the cell.
  • cell state $c_t=f_t\cdot c_{t-1}+i_t\cdot c'_t$
    • erase (forget) some content from the last cell state, and write (input) some new cell content.
  • hidden state $h_t=o_t\cdot\tanh(c_t)$
    • read (output) some content from the cell; its length is the same as that of $c_t$.
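The gate equations above translate almost line by line into code. The following is a minimal, dependency-free sketch in plain Python (lists instead of a tensor library). The helper names `affine`, `lstm_step` and `init_params`, the uniform initialization range and the zero biases are illustrative assumptions, not part of the original text.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(W, b, v):
    """W·v + b, with W stored as a list of rows."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) + b_i
            for row, b_i in zip(W, b)]


def lstm_step(p, h_prev, c_prev, x):
    """One LSTM timestep: returns (h_t, c_t)."""
    hx = h_prev + x  # list concatenation gives [h_{t-1}, x_t]
    f = [sigmoid(z) for z in affine(p["W_f"], p["b_f"], hx)]        # forget gate
    i = [sigmoid(z) for z in affine(p["W_i"], p["b_i"], hx)]        # input gate
    o = [sigmoid(z) for z in affine(p["W_o"], p["b_o"], hx)]        # output gate
    c_new = [math.tanh(z) for z in affine(p["W_c"], p["b_c"], hx)]  # c'_t
    c = [fk * ck + ik * nk for fk, ck, ik, nk in zip(f, c_prev, i, c_new)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c


def init_params(hidden, inputs, seed=0):
    """Small random weights and zero biases for the blocks f, i, o, c (assumed scheme)."""
    rng = random.Random(seed)
    p = {}
    for g in "fioc":
        p["W_" + g] = [[rng.uniform(-0.1, 0.1) for _ in range(hidden + inputs)]
                       for _ in range(hidden)]
        p["b_" + g] = [0.0] * hidden
    return p


params = init_params(hidden=3, inputs=2)
h, c = [0.0] * 3, [0.0] * 3
for x in ([1.0, 0.5], [-0.3, 0.8]):
    h, c = lstm_step(params, h, c, x)
print(h, c)
```

Since $o_t\in(0,1)$ and $\tanh(c_t)\in(-1,1)$, every entry of the hidden state stays strictly between -1 and 1, while the cell state itself is not bounded in this way.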

Preventing Vanishing Gradients with LSTMs

The biggest culprit in causing our gradients to vanish is the recursive derivative we need to compute: $\partial h_{T}/\partial h_t$. If only this derivative were “well behaved” (that is, it did not go to 0 or infinity as we backpropagate through time steps), we could learn long-term dependencies!


The original LSTM solution

The original motivation behind the LSTM was to make this recursive derivative have a constant value, so that our gradients would neither explode nor vanish.

The LSTM introduces a separate cell state $c_t$. In the original 1997 LSTM, the value of $c_t$ depends on the previous value of the cell state and an update term weighted by the input gate value:
$$c_t=c_{t-1} + i_t c'_t$$
This formulation doesn’t work well because the cell state tends to grow uncontrollably. In order to prevent this unbounded growth, a forget gate was added to scale the previous cell state, leading to the more modern formulation:
$$c_t=f_tc_{t-1}+i_tc'_t$$
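The growth problem is easy to see numerically. This sketch assumes, purely for illustration, that the update term $i_t c'_t$ is a constant 0.5 at every step and that the forget gate is a constant $f$. Without a forget gate ($f=1$) the cell state grows linearly, while any constant $f<1$ keeps it bounded by $0.5/(1-f)$.

```python
def run_cell(steps, forget, update=0.5, c0=0.0):
    """Iterate c_t = forget * c_{t-1} + update, with constant forget and update terms."""
    c = c0
    for _ in range(steps):
        c = forget * c + update
    return c


print(run_cell(100, forget=1.0))  # original 1997 form: grows linearly, 50.0 here
print(run_cell(100, forget=0.9))  # with a forget gate: approaches 0.5 / (1 - 0.9) = 5
```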


Looking at the full LSTM gradient

Let’s expand out the full derivation for $\partial c_t/\partial c_{t-1}$. First recall that in the LSTM, $c_t$ is a function of $f_t$ (the forget gate), $i_t$ (the input gate), and $c'_t$ (the candidate cell state), each of which is a function of $c_{t-1}$ (since they are all functions of $h_{t-1}$). Via the multivariate chain rule we get:
$$\frac{\partial{c_t}}{\partial c_{t-1}}=\frac{\partial c_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial c'_t}\frac{\partial c'_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + f_t$$
The final term $f_t$ comes from the direct path through $f_tc_{t-1}$ in $c_t=f_tc_{t-1}+i_tc'_t$.
Now if we want to backpropagate back k time steps, we simply multiply terms in the form of the one above k times. Note the big difference between this recursive gradient and the one for vanilla RNNs.

In vanilla RNNs, the terms $\partial h_{t}/\partial h_{t-1}$ will eventually take on values that are either always above 1 or always in the range [0, 1]; this is essentially what leads to the vanishing/exploding gradient problem. The terms here, $\partial c_t/\partial c_{t-1}$, can at any time step take on either values greater than 1 or values in the range [0, 1]. Thus, if we extend to an infinite number of time steps, we are not guaranteed to end up converging to 0 or infinity (unlike in vanilla RNNs).

If we start to converge to zero, we can always set the value of $f_t$ (and of the other gates) higher, say around 0.95 for $f_t$, in order to bring the value of $\partial c_t/\partial c_{t-1}$ closer to 1, thus preventing the gradients from vanishing (or at the very least, preventing them from vanishing too quickly). One important thing to note is that the values $f_t$, $o_t$, $i_t$ and $c'_t$ are things that the network learns to set. Thus, the network learns to decide when to let the gradient vanish and when to preserve it, by setting the gate values accordingly!

LSTM doesn’t guarantee that there is no vanishing/exploding gradient, but it does provide an easier way for the model to learn long-distance dependencies. This might all seem magical, but it really is just the result of two main things:

  • The additive update function for the cell state gives a derivative that’s much more ‘well behaved’;
  • The gating functions allow the network to decide how much the gradient vanishes, and can take on different values at each time step.
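The claim that the network can keep $\partial c_t/\partial c_{t-1}$ close to 1 by keeping $f_t$ high can be checked numerically. The sketch below uses a hypothetical scalar LSTM in which every gate weight is 0.5 and only the forget-gate bias $b_f$ varies, feeds it the constant input 0.1 for 50 steps, and estimates $\partial c_T/\partial c_0$ with a central finite difference. All of these values are assumptions chosen for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def final_cell(c0, xs, b_f, w=0.5, h0=0.0):
    """Scalar LSTM with all gate weights equal to w; only the forget bias b_f differs."""
    h, c = h0, c0
    for x in xs:
        pre = w * h + w * x    # shared pre-activation w*h_{t-1} + w*x_t
        f = sigmoid(pre + b_f)  # forget gate
        i = sigmoid(pre)        # input gate
        o = sigmoid(pre)        # output gate
        c_new = math.tanh(pre)  # candidate c'_t
        c = f * c + i * c_new   # cell update
        h = o * math.tanh(c)    # hidden state
    return c


def dcT_dc0(xs, b_f, eps=1e-6):
    """Central finite-difference estimate of dc_T / dc_0."""
    return (final_cell(eps, xs, b_f) - final_cell(-eps, xs, b_f)) / (2 * eps)


xs = [0.1] * 50
print(dcT_dc0(xs, b_f=3.0))   # f_t around 0.95: a sizable part of the gradient survives
print(dcT_dc0(xs, b_f=-3.0))  # f_t around 0.05: the gradient is effectively zero
```

With $b_f=3$ every $f_t$ is at least about 0.95, so the direct path alone contributes roughly $0.95^{50}\approx 0.08$; with $b_f=-3$ the same product is astronomically small and the estimate is dominated by rounding noise.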

Gradient clipping: solution for exploding gradient

If the norm of the gradient is greater than some threshold, scale it down before applying the SGD update ($\epsilon$ is the loss and $\theta$ the parameters):

  • $\hat{\pmb g}\leftarrow\dfrac{\partial\epsilon}{\partial\theta}$
  • if $\|\hat{\pmb g}\|\geq\text{threshold}$ then
    $\hat{\pmb g}\leftarrow \dfrac{\text{threshold}}{\|\hat{\pmb g}\|}\hat{\pmb g}$
  • end if
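The clipping rule above can be sketched in plain Python. For simplicity the sketch treats all gradients as one flat list of numbers, an assumption; in practice the norm is usually taken over all parameters together. PyTorch provides a similar utility as `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_norm(grads, threshold):
    """If ||g|| >= threshold, rescale g to have norm `threshold`; otherwise return it unchanged."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm >= threshold and norm > 0.0:
        scale = threshold / norm
        return [g * scale for g in grads]
    return list(grads)


print(clip_by_norm([3.0, 4.0], threshold=1.0))  # norm 5 rescaled to norm 1: about [0.6, 0.8]
print(clip_by_norm([0.3, 0.4], threshold=1.0))  # norm 0.5 is below the threshold: unchanged
```

Clipping changes only the length of the gradient, not its direction, so the update still points downhill but takes a bounded step.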

[Figure: the loss surface of a simple RNN whose hidden state is a scalar rather than a vector]


Gated Recurrent Units (GRU)

The GRU is a simpler alternative to the LSTM. On each timestep $t$, we have an input $x_t$ and a hidden state $h_t$ (no cell state).

On timestep $t$:

  • update gate: $z_t=\sigma(W_z\cdot[h_{t-1},x_t]+b_z)$
    • controls what parts of the hidden state are updated vs. preserved.
  • reset gate: $r_t=\sigma(W_r\cdot[h_{t-1},x_t]+b_r)$
    • controls what parts of the previous hidden state are used to compute the new content.
  • new hidden state content: $\tilde h_t=\tanh(W_{\tilde h}\cdot[r_t*h_{t-1}, x_t]+b_{\tilde h})$
    • the reset gate selects useful parts of the previous hidden state, which are combined with the current input to compute the new content.
  • hidden state: $h_t=(1-z_t)*h_{t-1} + z_t*\tilde h_t$
    • simultaneously controls what is kept from the previous hidden state and what is updated to the new hidden state content.
    • $z_t$ sets the balance between preserving things from the previous hidden state and writing new content.
    • if $z_t$ is set to zero, the hidden state is kept the same on every step, which retains information over long distances.

LSTM vs GRU

The biggest difference is that the GRU is quicker to compute and has fewer parameters. There is no conclusive evidence that one consistently performs better than the other.

Rule of thumb: start with an LSTM, but switch to a GRU if you want something more efficient. The LSTM is a good default choice, especially if your data has particularly long dependencies or you have lots of training data, because it has more parameters than the GRU and can therefore learn more complex dependencies.
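The parameter difference can be counted directly from the equations: the LSTM has four weight blocks ($f$, $i$, $o$, $c'$) and the GRU three ($z$, $r$, $\tilde h$), each consisting of a matrix of size hidden × (hidden + input) and one bias vector. The sketch below assumes exactly that layout. Real libraries may differ slightly (for example, PyTorch's `nn.LSTM` and `nn.GRU` keep two bias vectors per block), but the 4 : 3 ratio is the same.

```python
def lstm_params(input_size, hidden_size):
    """Four blocks (f, i, o, c'): weights hidden x (hidden + input) plus a bias each."""
    return 4 * (hidden_size * (hidden_size + input_size) + hidden_size)


def gru_params(input_size, hidden_size):
    """Three blocks (z, r, h~) with the same shapes as in the LSTM."""
    return 3 * (hidden_size * (hidden_size + input_size) + hidden_size)


print(lstm_params(100, 256))  # 365568
print(gru_params(100, 256))   # 274176
```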


Vanishing/Exploding Gradient Solutions

Vanishing/exploding gradients are a general problem for all neural architectures (including feed-forward and convolutional ones), especially deep ones; RNNs are particularly unstable due to the repeated multiplication by the same weight matrix.

  • due to the chain rule and the choice of nonlinearity, the gradient can become vanishingly small as it backpropagates;
  • thus lower layers are learnt very slowly (hard to train);
  • solution: add more direct connections (thus allowing the gradient to flow);
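The effect of a direct connection on gradient flow can be sketched with scalar layers. The example below is a toy illustration of the residual idea only, not an implementation of ResNet, DenseNet or HighwayNet. It stacks 30 layers of either $x \leftarrow \tanh(wx)$ or the residual form $x \leftarrow x + \tanh(wx)$, with an assumed $w=0.5$, and accumulates $\partial x_L/\partial x_0$ with the chain rule. In the plain stack every local factor is at most $w$, while the skip connection adds an identity term so every factor is at least 1.

```python
import math


def input_gradient(depth, w, residual, x0=0.5):
    """d x_L / d x_0 through `depth` scalar layers, accumulated with the chain rule."""
    x, grad = x0, 1.0
    for _ in range(depth):
        a = math.tanh(w * x)
        local = w * (1.0 - a * a)  # derivative of tanh(w * x) with respect to x
        if residual:
            # x_{k+1} = x_k + tanh(w x_k): the identity path contributes a constant 1
            x, grad = x + a, grad * (1.0 + local)
        else:
            # x_{k+1} = tanh(w x_k): each factor is at most w
            x, grad = a, grad * local
    return grad


print(input_gradient(30, 0.5, residual=False))  # below 0.5**30, i.e. vanishing
print(input_gradient(30, 0.5, residual=True))   # at least 1
```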

Residual connections (ResNet)


Dense connections (DenseNet)


Highway connections (HighwayNet)


Bidirectional RNNs

The contextual representation of a word is obtained by concatenating the hidden states of a forward RNN and a backward RNN. These two RNNs have separate weights.
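A minimal sketch of this idea with two scalar tanh RNNs follows. The weights are arbitrary illustrative values, and since each hidden state is a single number here, the "concatenation" is simply a pair. The backward RNN reads the sequence right to left and its states are then reversed, so that position $t$ pairs the forward state (left context) with the backward state (right context).

```python
import math


def run_rnn(xs, w_hx, w_hh, h0=0.0):
    """Scalar tanh RNN; returns the hidden state after each input."""
    h, states = h0, []
    for x in xs:
        h = math.tanh(w_hx * x + w_hh * h)
        states.append(h)
    return states


def bidirectional(xs, fwd, bwd):
    """Pair each forward state with the backward state at the same position."""
    forward = run_rnn(xs, *fwd)
    backward = run_rnn(xs[::-1], *bwd)[::-1]  # read right-to-left, then realign
    return list(zip(forward, backward))       # "concatenation" of two scalars


reps = bidirectional([1.0, 0.0, -1.0], fwd=(0.8, 0.5), bwd=(0.3, 0.9))
print(reps)  # each position sees both its left and its right context
```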


Multi-layer RNNs

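Multi-layer RNNs are powerful, but if the network is deep you might need skip/dense connections to train it; very deep models such as BERT (a Transformer rather than an RNN) rely on such skip connections.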

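A unidirectional multi-layer RNN can be computed either front to back (over time) or from bottom (input) to top (output), but a bidirectional multi-layer RNN can only be computed from bottom to top, one full layer at a time.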
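For the unidirectional case the two computation orders can be compared directly. The sketch below uses a hypothetical stack of scalar tanh RNN layers, each given as a pair `(w_in, w_rec)` with illustrative values, and computes the top layer's outputs once layer by layer and once timestep by timestep; both orders give the same result. In a bidirectional stack, layer $k+1$ at position $t$ needs the backward state of layer $k$, which depends on the whole sequence, so the timestep-first order is not available.

```python
import math


def layer_first(xs, layers):
    """Run each layer over the whole sequence before moving up to the next layer."""
    seq = list(xs)
    for w_in, w_rec in layers:
        h, out = 0.0, []
        for x in seq:
            h = math.tanh(w_in * x + w_rec * h)
            out.append(h)
        seq = out
    return seq


def time_first(xs, layers):
    """Advance all layers by one timestep before reading the next input."""
    hs = [0.0] * len(layers)
    top = []
    for x in xs:
        inp = x
        for k, (w_in, w_rec) in enumerate(layers):
            hs[k] = math.tanh(w_in * inp + w_rec * hs[k])
            inp = hs[k]
        top.append(inp)
    return top


layers = [(0.8, 0.5), (1.1, -0.3), (0.6, 0.9)]
xs = [1.0, -0.5, 0.25, 0.0]
print(layer_first(xs, layers))
print(time_first(xs, layers))  # same values as the line above
```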


Reference:

1. Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass


Reposted from blog.csdn.net/sinat_34072381/article/details/105840491