A detailed explanation of Batch Normalization, with a PyTorch experiment

Batch Normalization was proposed by a Google team in the 2015 paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift". The method speeds up network convergence and can also improve accuracy. Although there are many articles about it online, most of them simply restate the formulas from the paper and discuss them in general terms; how BN actually works is rarely explained. This article is divided into the following parts:

(1) The principle of BN

(2) Verifying these points with a PyTorch experiment

(3) Points to pay attention to when using BN (BN can be a real trap if misused)

1. Batch Normalization principle

During image preprocessing we usually standardize the input images, which speeds up network convergence. As shown in the figure below, the input to Conv1 is a feature matrix that satisfies a certain distribution, but the feature maps fed to Conv2 do not necessarily satisfy any particular distribution law (note that "satisfying a certain distribution law" here does not mean the data of one individual feature map must follow that distribution; in theory it refers to the feature-map data of the entire training set). The purpose of Batch Normalization is to make our feature maps satisfy a distribution with mean 0 and variance 1.

If you are still confused at this point, don't panic; take a sip of water and go through it slowly. Below is an excerpt from the original paper; pay attention to the highlighted part:

"For an input x with d dimensions, we will normalize each dimension of it." Assuming that our input x is a color image with three RGB channels, then d here is the channels of the input image, ie d=3, x=(x^{(1)}, x^{(2)}, x^{(3)}), which x^{(1)}represents the feature matrix corresponding to our R channel, and so on. Normalization processing is to process our R channel, G channel, and B channel separately. The above formula does not need to be read. The original text provides a more detailed calculation formula:

We just said that the feature maps should satisfy a certain distribution law; in theory this means the feature-map data of the entire training set should satisfy it, which would require computing the feature maps over the whole training set and then standardizing them. For a large dataset this is obviously impossible, so the paper introduces Batch Normalization: we compute the statistics of the feature maps of one batch of data and normalize with those (the larger the batch, the closer its statistics are to the distribution of the whole dataset, and the better the result). From the formula in the figure above, \mu_{B} is the mean of each dimension (channel) of the computed feature maps. Note that \mu_{B} is a vector, not a single value; each element of \mu_{B} is the mean of one dimension (channel). Likewise, \sigma_{B}^{2} is the variance of each dimension (channel) of the computed feature maps; it is also a vector, not a single value, and each of its elements is the variance of one dimension (channel). The normalized values are then computed from \mu_{B} and \sigma_{B}^{2}. The figure below gives an example of computing the mean \mu_{B} and variance \sigma_{B}^{2}:

The figure above shows the Batch Normalization computation for a batch size of 2 (two images). Suppose feature1 and feature2 are the feature matrices obtained from image1 and image2 after a series of convolution and pooling operations, and each feature has 2 channels. Then x^{(1)} denotes the channel-1 data of all features in the batch, and likewise x^{(2)} denotes the channel-2 data of all features in the batch. Computing the mean and variance of x^{(1)} and x^{(2)} gives our two vectors \mu_{B} and \sigma_{B}^{2}. Each channel is then normalized with the standardization formula (the \epsilon in the formula is a small constant that keeps the denominator from being zero). During training the network is trained on batches of data, but at prediction time we usually feed in a single image, i.e. a batch size of 1, and computing a mean and variance as above would be meaningless. Therefore, during training we keep computing the mean and variance of every batch and record running statistics of them with a moving average; after training, these statistics are approximately equal to the mean and variance of the entire training set. During validation and prediction we then use the recorded statistics for normalization.
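The same computation can be written out in a few lines. The sketch below is only an illustration of the figure above (random data, batch of 2, 2 channels, 2x2 feature maps) and is not yet the experiment from section 2:

import torch

# batch of 2 features, each with 2 channels and 2x2 spatial size: [batch, channel, H, W]
feature = torch.randn(2, 2, 2, 2)

mu_b = feature.mean(dim=(0, 2, 3))                     # vector of length 2: one mean per channel
sigma_b2 = feature.var(dim=(0, 2, 3), unbiased=False)  # vector of length 2: one (population) variance per channel

eps = 1e-5
x_hat = (feature - mu_b[None, :, None, None]) / torch.sqrt(sigma_b2[None, :, None, None] + eps)
# x_hat is the normalized feature; every channel now has roughly mean 0 and variance 1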

Careful readers will notice that the formula in the paper also contains two parameters, \gamma and \beta. Yes: \gamma is used to adjust the variance of the distribution, and \beta is used to shift its mean. Both parameters are learned during backpropagation; \gamma defaults to 1 and \beta defaults to 0.
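In PyTorch these two parameters are the weight and bias of the BN layer (with affine=True, the default). A quick way to check the default values:

import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.weight)  # gamma: initialized to ones, learnable (requires_grad=True)
print(bn.bias)    # beta:  initialized to zeros, learnable (requires_grad=True)
# the layer's final output is y = gamma * x_hat + beta, applied per channel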

2. Experiment with PyTorch

Do you think you understand it all now? Not necessarily. As just described, during training the mean and variance are computed from the current batch of data and denoted \mu_{now} and \sigma_{now}^{2}, while the mean and variance used during validation and prediction are running statistics, denoted \mu_{statistic} and \sigma_{statistic}^{2}. The update rule for \mu_{statistic} and \sigma_{statistic}^{2} is as follows, where momentum defaults to 0.1:

\mu_{statistic+1}=(1-momentum)*\mu_{statistic}+momentum*\mu_{now}

\sigma_{statistic+1}^{2}=(1-momentum)*\sigma_{statistic}^{2}+momentum*\sigma_{now}^{2}
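A minimal sketch of these two update formulas (momentum = 0.1, which is also PyTorch's default for nn.BatchNorm2d; the tensor shapes are chosen only for illustration):

import torch

momentum = 0.1                 # PyTorch's default for nn.BatchNorm2d
running_mean = torch.zeros(2)  # mu_statistic, one entry per channel
running_var = torch.ones(2)    # sigma_statistic^2, one entry per channel

batch = torch.randn(2, 2, 2, 2)                    # current batch, [N, C, H, W]
mu_now = batch.mean(dim=(0, 2, 3))                 # mean of the current batch
var_now = batch.var(dim=(0, 2, 3), unbiased=True)  # sample variance of the current batch (see below)

running_mean = (1 - momentum) * running_mean + momentum * mu_now
running_var = (1 - momentum) * running_var + momentum * var_now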

It should be noted that when PyTorch applies BN to the features of the current batch, it uses the population (biased) variance \sigma_{now}^{2}, computed as follows:

\sigma_{now}^{2}=\frac{1}{m}\sum_{i=1}^{m}(x_{i}-\mu_{now})^{2}

When updating the running statistic \sigma_{statistic}^{2}, however, the sample (unbiased) variance \sigma_{now}^{2} is used, computed as follows:

\sigma_{now}^{2}=\frac{1}{m-1}\sum_{i=1}^{m}(x_{i}-\mu_{now})^{2}
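The only difference between the two formulas is the 1/m versus 1/(m-1) factor. A tiny numerical check (with a made-up array), using the same ddof argument of numpy that the experiment below relies on:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
mu = x.mean()
m = len(x)

pop_var = ((x - mu) ** 2).sum() / m           # 1/m: population variance, used to normalize the batch
sample_var = ((x - mu) ** 2).sum() / (m - 1)  # 1/(m-1): sample variance, used to update the running variance

assert np.isclose(pop_var, x.var(ddof=0))     # numpy's default is ddof=0
assert np.isclose(sample_var, x.var(ddof=1))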

The following is the test I ran with PyTorch; the code is as follows:

(1) The bn_process function is a custom BN implementation used to verify that its result matches PyTorch's official BN. Inside bn_process, the mean and standard deviation of each dimension of the input batch (the dimension here is the channel dimension) are computed (the standard deviation equals the square root of the variance); each channel of the feature is then normalized using the computed mean and the population standard deviation, and finally the running mean and running variance are updated using the mean and the sample standard deviation.

(2) The initial running mean is a vector of zeros whose length equals the channel depth, and the initial running variance is a vector of ones whose length equals the channel depth; \gamma is initialized to 1 and \beta to 0.

import numpy as np
import torch.nn as nn
import torch


def bn_process(feature, mean, var):
    feature_shape = feature.shape
    for i in range(feature_shape[1]):
        # [batch, channel, height, width]
        feature_t = feature[:, i, :, :]
        mean_t = feature_t.mean()
        # population standard deviation (ddof=0)
        std_t1 = feature_t.std()
        # sample standard deviation (ddof=1)
        std_t2 = feature_t.std(ddof=1)

        # bn process
        # remember to add eps here to stay consistent with PyTorch
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
        # update the running (statistical) mean and variance
        mean[i] = mean[i] * 0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
    print(feature)


# randomly generate a feature tensor with batch=2, channel=2, height=width=2
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# initialize the running (statistical) mean and variance
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())

# note: use copy() to make a deep copy so feature1 itself is not modified
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

First, I set a breakpoint at the end of the script and debugged it to inspect the running mean and variance stored by the official BN layer after it processed the feature. We can see that the official layer's running_mean and running_var are exactly the same as the calculate_mean and calculate_var we computed ourselves (differing only in precision).
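(If you prefer not to use a debugger, the same check can be done by appending two print statements to the end of the script above; running_mean and running_var are the buffers that nn.BatchNorm2d keeps internally.)

print(bn.running_mean)  # should match calculate_mean (up to precision)
print(bn.running_var)   # should match calculate_var (up to precision)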

Then we print the output of the custom bn_process function and the output produced by the official BN layer; the results are obviously the same (again, differing only in precision):

3. Problems that need to be paid attention to when using BN

(1) Set the training parameter to True during training and to False during validation. In PyTorch this is controlled through the model's model.train() and model.eval() methods.
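A minimal sketch of that pattern (the model here is made up purely for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)
x = torch.randn(4, 3, 32, 32)

model.train()  # training mode: BN normalizes with batch statistics and updates running_mean / running_var
out_train = model(x)

model.eval()   # evaluation mode: BN normalizes with the recorded running statistics
with torch.no_grad():
    out_eval = model(x)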

(2) Set the batch size as large as possible; with a small batch size the performance can be very poor. The larger the batch, the closer the batch mean and variance are to the mean and variance of the entire training set.

(3) It is recommended to place the BN layer between the convolution layer (Conv) and the activation layer (e.g. ReLU), and the convolution layer should not use a bias, because it has no effect: even if a bias b is used, the result after BN is the same, i.e. y_{i}^{b}=y_{i} (a short derivation is sketched below).
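A short sketch of that reasoning, in the notation of section 1: a convolution bias b adds the same constant to every element of a channel, so it shifts the batch mean by exactly b and leaves the variance unchanged, and the normalized output is therefore identical:

y_{i}^{b}=\frac{(x_{i}+b)-(\mu_{B}+b)}{\sqrt{\sigma_{B}^{2}+\epsilon}}=\frac{x_{i}-\mu_{B}}{\sqrt{\sigma_{B}^{2}+\epsilon}}=y_{i}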

Finally, here is Li Hongyi's video explanation of Batch Normalization:

Li Hongyi Deep Learning (2017) - bilibili


Origin blog.csdn.net/qq_37541097/article/details/104434557#comments_20942083