The Principle of Batch Normalization, Vanishing Gradients, and Exploding Gradients

Analysis of Batch Normalization Principle

Foreword

This article collects Batch Normalization-related material that I compiled from several books and blog posts. The whole article is organized according to my own understanding, for future reference. References are listed at the end of the text.

Batch Normalization is often used to address vanishing gradients and exploding gradients, as well as the internal covariate shift described in the original paper. This article therefore first goes over the principles behind vanishing gradients, exploding gradients, and internal covariate shift, and then analyzes the principle of Batch Normalization.

1.1 Vanishing gradients and exploding gradients

Some papers (for example, the ResNet paper) and technical books mention that Batch Normalization can be used to alleviate vanishing and exploding gradients. Here, following the book "PyTorch in Simple Terms", we outline how vanishing and exploding gradients arise.

[Figure: forward computation of one layer, with input h_j, weight W_j, and output h_{j+1}]
where $\mathbf{h}_j$ is the input to the neurons of layer $j$, $\mathbf{W}_j$ is the weight of layer $j$, and $\mathbf{h}_{j+1}$ is the output of this layer, which in turn serves as the input to the next layer. Without an activation function, $\mathbf{h}_{j+1} = \mathbf{W}_j \mathbf{h}_j$; after adding the activation function, $\mathbf{h}_{j+1} = f_j(\mathbf{W}_j \mathbf{h}_j)$.
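As a minimal sketch of this forward computation (the layer sizes and the sigmoid activation here are made-up choices for illustration):

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: layer j has 4 inputs and 3 output neurons.
h_j = torch.randn(4)      # input of layer j
W_j = torch.randn(3, 4)   # weight of layer j
f_j = torch.sigmoid       # activation function of layer j

h_j1 = f_j(W_j @ h_j)     # h_{j+1} = f_j(W_j h_j)
print(h_j1.shape)         # torch.Size([3])
```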

According to the chain rule in calculus, if $f(\mathbf{x})$ depends on $\mathbf{x}$ through an intermediate variable $\mathbf{y}$, its derivative with respect to $\mathbf{x}$ is:
$$\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}} = \frac{\partial f(\mathbf{x})}{\partial \mathbf{y}} \frac{\partial \mathbf{y}}{\partial \mathbf{x}}$$
Assume the final loss is $L = f_n(\mathbf{h}_n)$, a function of the output-layer neurons. Differentiating with respect to the weights and the inputs of layer $j$ via the chain rule gives:
$$\frac{\partial L}{\partial \mathbf{W}_j} = \frac{\partial L}{\partial \mathbf{h}_{j+1}} \frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{W}_j} = \left( \frac{\partial L}{\partial \mathbf{h}_{j+1}} \odot \frac{\partial f_j(\mathbf{W}_j \mathbf{h}_j)}{\partial \mathbf{W}_j \mathbf{h}_j} \right) \mathbf{h}_j^T$$
$$\frac{\partial L}{\partial \mathbf{h}_j} = \frac{\partial L}{\partial \mathbf{h}_{j+1}} \frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_j} = \mathbf{W}_j^T \left( \frac{\partial L}{\partial \mathbf{h}_{j+1}} \odot \frac{\partial f_j(\mathbf{W}_j \mathbf{h}_j)}{\partial \mathbf{W}_j \mathbf{h}_j} \right)$$
Here $\frac{\partial L}{\partial \mathbf{h}_j}$ can be regarded as the derivative of the loss with respect to the data, i.e. the data gradient, and $\frac{\partial L}{\partial \mathbf{W}_j}$ is the derivative of the loss with respect to the weights, i.e. the weight gradient. From these formulas we can see that the data gradient depends on the weights, the weight gradient depends on the data, and both the data gradient and the weight gradient of an earlier layer depend on the data gradient of the layer after it.
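To sanity-check the second formula, here is a minimal autograd sketch (toy sizes, a sigmoid activation, and a made-up loss $L = \sum h_{j+1}$) that compares PyTorch's gradient for $\mathbf{h}_j$ with the hand-derived expression $\mathbf{W}_j^T\left(\frac{\partial L}{\partial \mathbf{h}_{j+1}} \odot f_j'(\mathbf{W}_j\mathbf{h}_j)\right)$:

```python
import torch

torch.manual_seed(0)

h_j = torch.randn(4, requires_grad=True)     # input of layer j
W_j = torch.randn(3, 4, requires_grad=True)  # weight of layer j

z = W_j @ h_j                  # W_j h_j
h_j1 = torch.sigmoid(z)        # h_{j+1} = f_j(W_j h_j)
L = h_j1.sum()                 # a toy loss built from h_{j+1}
L.backward()

# Hand-computed data gradient: W_j^T (dL/dh_{j+1} ⊙ f_j'(W_j h_j)).
dL_dh_j1 = torch.ones(3)                             # dL/dh_{j+1} for L = sum(h_{j+1})
f_prime = torch.sigmoid(z) * (1 - torch.sigmoid(z))  # sigmoid derivative
manual = W_j.T @ (dL_dh_j1 * f_prime)

print(torch.allclose(h_j.grad, manual.detach()))     # True
```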

With this, vanishing and exploding gradients can be explained:

  • Vanishing gradients: when the network is very deep, different layers learn at very different speeds. In practice this shows up as the layers near the output learning well while the layers near the input learn very slowly. There are several causes, such as improper weight initialization or an unsuitable activation function; the activation function is the easier one to illustrate.
    [Figure: the Sigmoid and Tanh activation functions]
    If Sigmoid or Tanh is used as the activation function, their derivatives are at most 1 (for Sigmoid, at most 0.25). This means that during backpropagation, every time the gradient passes back through a layer it is multiplied by the derivative of the activation, $\frac{\partial f(\mathbf{h}_j)}{\partial \mathbf{h}_j}$, which is less than 1. So the more layers the gradient propagates back through, the smaller the data gradient becomes, and the smaller the corresponding weight gradient, which causes the gradient to vanish. This is one reason ReLU is commonly used as the activation function: its derivative is 1 for positive inputs. Improper weight initialization, such as weights that are too small, can cause the same problem.
  • Exploding gradients: if weight initialization makes some weight values too large, then during backpropagation the data gradient grows each time it passes back through a layer, and the corresponding weight gradients grow with it, so the weight gradients become excessively large.

In summary, weight initialization and the choice of activation function are the main causes of vanishing and exploding gradients, so when initializing weights we should try to keep the scale of the values propagated through each layer close to 1, so that gradients neither shrink nor grow as they travel backwards; the sketch below illustrates both effects.
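A minimal sketch of both effects (toy layer sizes, made-up weight scales, and a dummy loss; not taken from the referenced book):

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(depth, act, weight_std):
    """Build a deep MLP, run one backward pass, and return the
    gradient norm of the first layer's weights."""
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        linear = nn.Linear(64, 64, bias=False)
        nn.init.normal_(linear.weight, std=weight_std)
        layers += [linear, act()]
    net = nn.Sequential(*layers)

    x = torch.randn(8, 64)
    net(x).pow(2).mean().backward()
    return net[0].weight.grad.norm().item()

for depth in (5, 15, 25):
    vanishing = first_layer_grad_norm(depth, nn.Sigmoid, weight_std=0.1)  # small weights + saturating activation
    exploding = first_layer_grad_norm(depth, nn.ReLU, weight_std=0.3)     # overly large weights
    print(f"depth={depth:2d}  sigmoid/small init: {vanishing:.2e}  relu/large init: {exploding:.2e}")
```

As depth grows, the first-layer gradient of the sigmoid network shrinks towards zero, while that of the over-large initialization grows rapidly, matching the two cases described above.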

2.1 Internal covariate shift

Internal covariate shift is the problem raised in the Batch Normalization paper. As mentioned above, a deep neural network stacks many layers, and each layer's parameter update changes the distribution of the inputs seen by the layer above it. As these changes accumulate layer by layer, the input distribution of the higher layers can shift drastically, forcing the higher layers to constantly re-adapt to the parameter updates of the lower layers.

In other words, the input data undergoes a nonlinear transformation at every layer of the network until the last layer, so by then the distribution of the data has changed while the ground truth has not. As a result, the neurons in later layers must keep adjusting their parameters to fit the new data distribution, and each layer's update affects the next layer, so the optimizer's hyperparameters have to be set very cautiously.
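A rough way to see this shift in code (a toy sketch with made-up layer sizes and an intentionally large learning rate, just to make the effect visible): freeze a batch, record the statistics of the input that a later layer receives, take one optimizer step, and look again.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer1 = nn.Linear(20, 20)
layer2 = nn.Linear(20, 1)
x = torch.randn(128, 20)   # a fixed batch
y = torch.randn(128, 1)    # dummy targets

def layer2_input_stats():
    """Mean/std of the activations that layer2 receives for the fixed batch."""
    with torch.no_grad():
        h = torch.relu(layer1(x))
    return h.mean().item(), h.std().item()

print("before update:", layer2_input_stats())

# One (deliberately large) SGD step updates layer1's parameters ...
opt = torch.optim.SGD(list(layer1.parameters()) + list(layer2.parameters()), lr=0.5)
loss = nn.functional.mse_loss(layer2(torch.relu(layer1(x))), y)
loss.backward()
opt.step()

# ... so the distribution that layer2 sees has shifted, even though x did not change.
print("after update: ", layer2_input_stats())
```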

3.1 Batch Normalization principle

The following is an analysis of the principle of Batch Normalization. To mitigate internal covariate shift, we would like the input of each layer of the network to follow a stable, roughly fixed distribution (ideally independent and identically distributed across training), and this is what Batch Normalization sets out to achieve.

Take convolutional neural networks as an example. Suppose a certain layer of our network has $k$ neurons and its previous layer has $j$ neurons. The output of layer $j$ then has shape $[B, j, H_1, W_1]$, where $B$ is the batch size and $j$ is the number of channels output by that layer. This output is passed as input to layer $k$, which has $k$ neurons, i.e. $k$ output channels. The weight of each neuron in layer $k$ has shape $[j, S, S]$, where $S$ is the size of the convolution kernel. Convolving each neuron's weight with the input of shape $[B, j, H_1, W_1]$ gives a result of shape $[B, 1, H_2, W_2]$, and since there are $k$ such neurons, the overall output of layer $k$ has shape $[B, k, H_2, W_2]$. The figure below illustrates this:
[Figure: layer j's output [B, 4, H1, W1] fed into layer k, which has 2 neurons with weights of shape [4, S, S] and produces output [B, 2, H2, W2]]
In the figure above, the output of layer $j$ is $[B, 4, H_1, W_1]$ and is passed as input to layer $k$. Layer $k$ has two neurons, each with a weight of shape $[4, S, S]$. Each neuron's weight is convolved with the input to produce a result of shape $[B, 1, H_2, W_2]$; concatenating the results of the two neurons gives the overall output $[B, 2, H_2, W_2]$.
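These shapes can be checked with a short sketch (the batch size, spatial size, kernel size, and padding below are made-up values):

```python
import torch
import torch.nn as nn

B, H1, W1, S = 8, 32, 32, 3    # made-up batch size, spatial size, kernel size

x = torch.randn(B, 4, H1, W1)  # output of layer j: 4 channels
layer_k = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=S, padding=1)

print(layer_k.weight.shape)  # torch.Size([2, 4, 3, 3]): 2 neurons, each with weight [4, S, S]
print(layer_k(x).shape)      # torch.Size([8, 2, 32, 32]): [B, 2, H2, W2]
```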

Batch Normalization acts on the output of layer $k$. Continuing to assume that layer $k$ has $k$ neurons and the batch size is $m$, i.e. $m$ data samples, the output of layer $k$ has shape $[m, k, H, W]$: $m$ samples, each with $k$ channels, and each channel an $[H, W]$ matrix. Batch Normalization normalizes each channel over the $m$ samples, as shown in the figure below:
[Figure: Batch Normalization normalizing each channel of the [m, k, H, W] output across the batch]
This is where BN is applied when we use it: generally a Conv layer is followed by a BN layer and then an activation layer such as ReLU.
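A minimal PyTorch sketch of this Conv → BN → ReLU ordering (the channel counts and spatial size are arbitrary, carried over from the example above):

```python
import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(4, 2, kernel_size=3, padding=1, bias=False),  # Conv layer
    nn.BatchNorm2d(2),                                       # BN over the 2 output channels
    nn.ReLU(inplace=True),                                   # activation after BN
)

x = torch.randn(8, 4, 32, 32)
print(block(x).shape)   # torch.Size([8, 2, 32, 32])
```

The Conv layer's bias is often omitted here because BN's shift parameter β makes it redundant. Now let's look at the specific formula of BN.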

Continuing the example above, the $m$ data samples pass through layer $k$ and produce an output of shape $[m, k, H, W]$: $m$ samples, each with $k$ channels, each channel an $[H, W]$ matrix. Batch Normalization is applied to this output; taking the first channel across the $m$ samples as an example, it computes:
$$\mu_1 = \frac{1}{m}\sum_{i=1}^{m} x_{1i}$$
$$\sigma_1^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_{1i}-\mu_1\right)^2$$
$$\hat{x}_1 \leftarrow \frac{x_{1i}-\mu_1}{\sqrt{\sigma_1^2+\epsilon}}$$
$$y_1 \leftarrow \gamma_1 \hat{x}_1 + \beta_1 \equiv \mathrm{BN}_{\gamma_1,\beta_1}(x_1)$$
where $x_1$ denotes the first channel of the entire batch and $x_{1i}$ denotes the first channel of the $i$-th sample. This operation can be divided into two steps:

  • Standardization: first standardize the $m$ values $x_{1i}$ to obtain a zero-mean, unit-variance distribution $\hat{x}_1$;
  • Scale and shift: then scale and shift $\hat{x}_1$ into a new distribution $y_1$ with a new mean and variance.

$\gamma_1$ and $\beta_1$ are learnable scale and shift parameters that control the variance and mean of $y_1$.
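A sketch of this per-channel computation in PyTorch, checked against nn.BatchNorm2d in training mode (shapes are made up; a fresh BatchNorm2d has γ = 1 and β = 0, matching the manual values below):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, k, H, W = 16, 2, 8, 8             # made-up batch size and output shape
x = torch.randn(m, k, H, W)          # output of layer k for a batch of m samples

# Manual BN: mean/variance of each channel over the batch and spatial dimensions.
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_hat = (x - mu) / torch.sqrt(var + 1e-5)
gamma = torch.ones(1, k, 1, 1)       # learnable scale, initialized to 1
beta = torch.zeros(1, k, 1, 1)       # learnable shift, initialized to 0
y_manual = gamma * x_hat + beta

bn = nn.BatchNorm2d(k, eps=1e-5)     # freshly initialized: gamma = 1, beta = 0
y_torch = bn(x)                      # training mode by default

print(torch.allclose(y_manual, y_torch, atol=1e-5))  # True
```

Note that for convolutional outputs the mean and variance of each channel are taken over both the batch and the spatial dimensions, which is what the sketch above does.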
