Numerical stability of deep learning models - explanation of gradient decay and gradient explosion

0. Preface

Following standard practice, let me first state: this article only reflects my own understanding of what I have learned. Although I draw on the valuable insights of others, it may still contain inaccuracies. If you find errors, I welcome your criticism and corrections so that we can improve together.

The purpose of this article is to explain two common numerical-stability problems in deep learning models, gradient decay (vanishing) and gradient explosion, as well as common solutions.

Part of the content, opinions and illustrations in this article are based on Lecture 15: Exploding and Vanishing Gradients from the School of Computer Science at the University of Toronto, as well as Section 3.15 "Numerical Stability and Model Initialization" of Dive into Deep Learning.

1. Why do gradient decay and gradient explosion occur?

Let's use the simplified fully connected network below to explain. It has only one neuron in each layer, so it can be regarded as a chain of neurons connected in series.

[Figure: a fully connected network with a single neuron per layer]

In forward propagation, since values are passed through a nonlinear activation function $\sigma(\cdot)$ (such as Sigmoid or Tanh), their magnitude is bounded, so forward propagation generally does not suffer from numerical stability problems.

In backpropagation, for example, the partial derivative of the output $y$ with respect to the weight $w_1$ is:

$$\frac{\partial y}{\partial w_1}=\sigma'(z_n)\,w_n \cdot \sigma'(z_{n-1})\,w_{n-1} \cdots \sigma'(z_1)\,x$$

$$z_n= \begin{cases} w_n \cdot h_{n-1}+b_n, & n>1\\ w_1 \cdot x+b_1, & n=1 \end{cases}$$
Here we can see that if the initial choice of the weights $w_n$ is unreasonable, or if during optimization most or all of the factors $\sigma'(z_n)\,w_n$ become greater than 1 or less than 1, and the network is deep enough, the partial derivatives in backpropagation become numerically unstable: gradient decay or gradient explosion.

To simplify the understanding, assume $\sigma'(z_n)\,w_n=0.8$ for every layer of a 50-layer network: $0.8^{50}\approx 0.000014$. Assume instead $\sigma'(z_n)\,w_n=1.2$: $1.2^{50}\approx 9100$.
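To make this concrete, here is a minimal sketch (assuming PyTorch; the depth of 50, the random weights and the input value 0.5 are arbitrary choices for illustration) of the one-neuron-per-layer chain described above. The forward output stays bounded by the sigmoid, while the gradient reaching $w_1$ all but vanishes:

```python
import torch

torch.manual_seed(0)
depth = 50  # 50 layers, one neuron per layer

# one weight and one bias per layer
ws = [torch.randn(1, requires_grad=True) for _ in range(depth)]
bs = [torch.zeros(1, requires_grad=True) for _ in range(depth)]

h = torch.tensor([0.5])           # input x (arbitrary)
for w, b in zip(ws, bs):
    h = torch.sigmoid(w * h + b)  # z_n = w_n * h_{n-1} + b_n, h_n = sigma(z_n)

h.sum().backward()
print(f"output y  = {h.item():.4f}")          # stays in (0, 1): forward pass is stable
print(f"dy/dw_1   = {ws[0].grad.item():.3e}") # vanishingly small after 50 layers
print(0.8 ** 50, 1.2 ** 50)                   # the two compounding examples above
```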

Another way to explain numerical stability, following Lecture 15: Exploding and Vanishing Gradients, is to view a deep network as the repeated iteration of a nonlinear function, for example $f(x)=3.5x(1-x)$. After many iterations we obtain $y=f(f(\cdots f(x)))$, shown below.

[Figure: the curve of $y=f(f(\cdots f(x)))$ after repeated iterations]

After a few iterations the curve becomes extremely steep in some regions, i.e. $\frac{\partial y}{\partial x}$ becomes very large there (corresponding to gradient explosion).

We should also notice that after 6 iterations there are regions where $\frac{\partial y}{\partial x}\approx 0$ (corresponding to gradient decay).
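A small sketch of that argument (the starting points and the 6 iterations are arbitrary choices): iterating $f(x)=3.5x(1-x)$ and accumulating $f'(x)$ by the chain rule shows how $\partial y/\partial x$ can collapse toward 0 or grow large depending on where $x$ starts:

```python
def f(x):
    return 3.5 * x * (1 - x)

def df(x):
    return 3.5 * (1 - 2 * x)   # derivative of f

def iterate(x, n=6):
    """Return y = f(f(...f(x))) and dy/dx accumulated by the chain rule."""
    grad = 1.0
    for _ in range(n):
        grad *= df(x)          # chain rule: multiply the local derivatives
        x = f(x)
    return x, grad

for x0 in (0.123, 0.5, 0.83):  # arbitrary starting points for illustration
    y, g = iterate(x0)
    print(f"x0={x0}: y={y:.4f}, dy/dx={g:.3e}")
```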

2. How to improve numerical stability?

2.1 Random initialization of model parameters

This is the simplest and most commonly used method to combat gradient decay and gradient explosion. As explained above, if most or all of the factors $\sigma'(z_n)\,w_n$ are greater than 1 or less than 1 and the network is deep enough, numerical instability easily occurs. Randomly initializing the model parameters largely avoids gradient decay or explosion caused by an unreasonable initial choice of $w_n$.

Xavier random initialization is a commonly used method: assuming the number of inputs of a hidden layer is $a$ and the number of outputs is $b$, Xavier random initialization samples the weight parameters of this layer uniformly from $\left(-\sqrt{\frac{6}{a+b}},\ \sqrt{\frac{6}{a+b}}\right)$.
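In PyTorch this initialization is available directly; a minimal sketch (the layer sizes 256 and 128 are arbitrary choices):

```python
import torch
import torch.nn as nn

layer = nn.Linear(256, 128)            # a = 256 inputs, b = 128 outputs (arbitrary sizes)
nn.init.xavier_uniform_(layer.weight)  # sample from U(-sqrt(6/(a+b)), sqrt(6/(a+b)))
nn.init.zeros_(layer.bias)

bound = (6 / (256 + 128)) ** 0.5
print(layer.weight.abs().max().item() <= bound)  # True: all weights fall inside the bound
```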

2.2 Gradient Clipping

This is a method of artificially limiting gradients that are too large. The idea is to multiply the original gradient $g$ by a coefficient that scales it down when the norm of $g$ is too large, and vice versa. This coefficient is:

$$\frac{\eta}{||g||}$$

Here $\eta$ is a hyperparameter and $||g||$ is the L2 norm of the gradient. In practice the gradient is usually rescaled only when its norm exceeds the threshold, i.e. $g \leftarrow \min\left(1, \frac{\eta}{||g||}\right) g$.

Although applying this coefficient means the result is no longer the true partial derivative of the loss function with respect to the weights, it maintains numerical stability.
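A minimal sketch of gradient clipping in a PyTorch training step (the model, the dummy data and the max_norm value are arbitrary placeholders), using torch.nn.utils.clip_grad_norm_, which rescales the gradients in place when their overall norm exceeds max_norm:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 64), nn.Tanh(), nn.Linear(64, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x, y = torch.randn(32, 10), torch.randn(32, 1)   # dummy batch

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
# rescale the gradients so that their total L2 norm is at most 1.0 (the threshold above)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```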

2.3 Regularization

This is a way to suppress exploding gradients. I have introduced the regularization method before: weight decay in PyTorch - the L2 norm regularization method (with code). The idea is to add the norm of the weights to the loss function as a penalty term:
$$loss = \frac{1}{n} \Sigma (y - \widehat{y})^2 + \frac{\lambda}{2n}||w||^2$$

As the deep learning model keeps iterating (learning) and the loss gets smaller and smaller, the norm of the weights also gets smaller and smaller, which suppresses gradient explosion.
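In PyTorch this L2 penalty is commonly applied through the optimizer's weight_decay argument rather than by writing the term into the loss by hand; a minimal sketch (the value 1e-4 for $\lambda$ is an arbitrary choice):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
# weight_decay adds an L2 penalty on the parameters, playing the role of the lambda term above
# (note: it also decays the biases unless separate parameter groups are used)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
```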

2.4 Batch Normalization

Batch Normalization is a data standardization method that adds learnable scaling and shifting on top of normalization. For its specific working principle, please refer to: Description of Batch Normalization.

The basic principle by which Batch Normalization maintains numerical stability is similar to gradient clipping: both artificially rescale values to keep them within a reasonable range, neither too large nor too small. The difference is that gradient clipping acts directly on the partial derivative of the loss function with respect to the weights during backpropagation, while Batch Normalization standardizes the output of a layer during forward propagation and thereby indirectly keeps the partial derivatives of the weights stable.

It should also be pointed out that since the input $x$ also participates in the calculation of the partial derivatives, if $x$ is a high-dimensional vector, applying Batch Normalization to the input $x$ is also necessary.
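A minimal sketch of where Batch Normalization sits in a network (assuming PyTorch; the layer sizes are arbitrary), including normalization of the input as mentioned above:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.BatchNorm1d(20),   # standardize the high-dimensional input x
    nn.Linear(20, 64),
    nn.BatchNorm1d(64),   # standardize the layer output, then scale and shift (learnable)
    nn.Sigmoid(),
    nn.Linear(64, 1),
)

x = torch.randn(32, 20)   # dummy batch of 32 samples
print(model(x).shape)     # torch.Size([32, 1])
```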

2.5 LSTM? Short Cut!

Many articles state that LSTM (long short-term memory) networks can help maintain numerical stability. I was puzzled when I first saw them, because what we need is a general method to improve the numerical stability of an existing model, not a wholesale replacement of it with an LSTM. Moreover, LSTM is not a universal deep learning model; we cannot simply swap in an LSTM whenever we encounter gradient decay or gradient explosion.

If you don’t know what LSTM is, you can read this: Algorithm introduction and mathematical derivation of LSTM (long short-term memory) network

Later, after watching Lecture 15: Exploding and Vanishing Gradients, I understood the source of the confusion: the lecture uses RNNs as its example of numerical instability. For RNNs, LSTM is indeed an improved model, because its internal "gate" structure that maintains "long-term memory" really does help improve numerical stability.

I think most articles that single out LSTM as a way to improve numerical stability have misunderstood this point.

The Short Cut (skip connection) structure is the general rule for improving numerical stability; LSTM is just a special case of it that improves RNNs.
[Figure: a residual (Short Cut) connection]

For the specific mechanism of Short Cut, please refer to Kaiming He's original paper: Deep Residual Learning for Image Recognition.
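A minimal sketch of a residual (Short Cut) block in PyTorch, not the exact block from the paper, just to show the idea: the input skips over the transformation and is added back, so gradients have a direct path through the identity branch:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """y = x + F(x): the identity shortcut gives gradients a direct path back."""
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.body(x)   # shortcut connection

x = torch.randn(32, 64)
block = ResidualBlock(64)
print(block(x).shape)             # torch.Size([32, 64])
```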
