[Normalization Methods] (1) Batch Normalization: principle analysis and code reproduction, with complete PyTorch code

Hello everyone, today I will share the common normalization methods used in deep learning. This first post introduces the most widely used one, Batch Normalization, and reproduces its code starting from the mathematical formulas.


1. Principle Analysis

When training a deep network, updating the parameters of an earlier layer changes the input distribution of the layers that follow it. Take the second layer as an example: its input is computed from the parameters and the input data of the first layer. Since the first layer's parameters keep changing throughout training, the input distribution of every subsequent layer inevitably changes as well.

Because the model parameters are constantly being updated and propagated forward, the input distribution of every intermediate layer keeps shifting during training. Researchers call this change of the data distribution inside the network internal covariate shift (ICS). ICS forces the use of a smaller learning rate and careful weight initialization, which slows down training, and it also aggravates the vanishing-gradient problem when saturating nonlinear activation functions are used (e.g. sigmoid, whose gradient approaches 0 on both the far positive and far negative sides).

To counter internal covariate shift, the idea is to normalize the input of each layer. This is the Batch Normalization (BN) algorithm, which consists of the following three steps:

(1) Compute the statistics. Compute the statistics needed for normalization, namely the mean and variance, over the mini-batch samples. Suppose the input is x \in R^{m\times d}, where m is the current batch size, i.e. how many training samples the current batch contains, and d is the feature dimension.

E\left[x^{(k)}\right]=\frac{1}{m} \sum_{i=1}^{m} x_{i}^{(k)}

\operatorname{Var}\left[x^{(k)}\right]=\frac{1}{m} \sum_{i=1}^{m}\left(x_{i}^{(k)}-E\left[x^{(k)}\right]\right)^{2}

(2) Normalization. Treat each element of the input vector as an independent random variable and normalize it separately; since the variables are treated as independent, no covariance matrix is involved. Even when the variables are in fact correlated, this normalization still speeds up convergence, and it can be seen as an approximate whitening. For d-dimensional input data x=(x^{(1)},\dots,x^{(d)}), each dimension is normalized as follows.

\hat{x}^{(k)}=\frac{x^{(k)}-E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}

(3) Linear transformation. Normalizing the input alone may change its original characteristics or distribution. For example, placing batch normalization before a sigmoid confines its inputs to the near-linear region and weakens the nonlinearity. To solve this, a learnable gain \gamma and bias \beta are introduced so that the original distribution can be recovered if needed.

y^{(k)}=\gamma^{(k)}\cdot\hat{x}^{(k)}+\beta^{(k)}

In particular, when \gamma^{(k)}=\sqrt{Var[x^{(k)}]} and \beta^{(k)}=E[x^{(k)}], the transformation can in theory recover the same distribution as the input. In experiments they are usually initialized to \gamma=1 and \beta=0. The purpose of this linear transformation is to give the network the option of undoing the normalization that BN "deliberately" introduces during training.
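
Putting the three steps together, here is a minimal tensor-level sketch for a 2-D input of shape [m, d]. This is only an illustration under assumed shapes and a small eps (the full trainable module is given in Section 2):

import torch

# minimal sketch of the three BN steps for an input of shape [m, d]
# (m = batch size, d = feature dimension); eps, gamma, beta are illustrative
m, d, eps = 4, 3, 1e-5
x = torch.randn(m, d)

# (1) per-dimension statistics over the mini-batch
mean = x.mean(dim=0)                     # [d]
var = x.var(dim=0, unbiased=False)       # [d], biased variance as in the formula

# (2) normalize each dimension independently
x_hat = (x - mean) / torch.sqrt(var + eps)

# (3) learnable linear transformation, initialized to gamma=1, beta=0
gamma, beta = torch.ones(d), torch.zeros(d)
y = gamma * x_hat + beta                 # [m, d]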

Based on the three steps above, the overall procedure of the batch normalization algorithm is as follows:

Input:

a mini-batch B=\{x_{1\dots m}\} of m samples; the parameters to be learned are the gain \gamma and the bias \beta

Output:

y_i=BN_{\gamma, \beta}(x_i)

\mu_B\leftarrow\frac{1}{m}\sum_{i=1}^m x_i

\sigma_B^2\leftarrow\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2

\hat{x}_{i}\leftarrow\frac{x_{i}-\mu_{B}}{\sqrt{\sigma_{B}^{2}+\epsilon}}

y_i\leftarrow\gamma\hat{x_i}+\beta\equiv\operatorname{BN}_{\gamma,\beta}(x_i)
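
As a quick worked example (ignoring \epsilon for readability), take a mini-batch of four scalar activations B=\{1,2,3,6\} with \gamma=1 and \beta=0:

\mu_B=\frac{1+2+3+6}{4}=3,\qquad \sigma_B^2=\frac{(-2)^2+(-1)^2+0^2+3^2}{4}=3.5

\hat{x}=\frac{x-3}{\sqrt{3.5}}\approx(-1.07,\ -0.53,\ 0,\ 1.60),\qquad y=1\cdot\hat{x}+0=\hat{x}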


The gradients needed to update the parameters of the BN transformation can be derived with the chain rule.
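
For reference, applying the chain rule to the transformation above (with loss \ell) yields the gradients given in the original BN paper:

\frac{\partial \ell}{\partial \hat{x}_i}=\frac{\partial \ell}{\partial y_i}\cdot\gamma

\frac{\partial \ell}{\partial \sigma_B^2}=\sum_{i=1}^m \frac{\partial \ell}{\partial \hat{x}_i}\cdot(x_i-\mu_B)\cdot\left(-\frac{1}{2}\right)\left(\sigma_B^2+\epsilon\right)^{-3/2}

\frac{\partial \ell}{\partial \mu_B}=\sum_{i=1}^m \frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma_B^2+\epsilon}}+\frac{\partial \ell}{\partial \sigma_B^2}\cdot\frac{\sum_{i=1}^m -2(x_i-\mu_B)}{m}

\frac{\partial \ell}{\partial x_i}=\frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma_B^2+\epsilon}}+\frac{\partial \ell}{\partial \sigma_B^2}\cdot\frac{2(x_i-\mu_B)}{m}+\frac{\partial \ell}{\partial \mu_B}\cdot\frac{1}{m}

\frac{\partial \ell}{\partial \gamma}=\sum_{i=1}^m \frac{\partial \ell}{\partial y_i}\cdot\hat{x}_i,\qquad \frac{\partial \ell}{\partial \beta}=\sum_{i=1}^m \frac{\partial \ell}{\partial y_i}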

The above describes batch normalization during training. At test time, normalization instead uses the unbiased estimates of the mean and variance accumulated during training.

Batch normalization replaces the original input x with BN(x). During training the normalization depends on the mini-batch, but at test time only a fixed input-output mapping is needed. The test-phase \mu and \sigma are therefore computed as:

E[x]\leftarrow E_B[\mu_B]

Var[x]\leftarrow\frac{m}{m-1}E_B[\sigma_B^2]

That is, at test time the mean is the average of the mini-batch means seen during training, and the variance is the unbiased estimate built from the mini-batch variances. For a test sample x, the batch normalization operation at test time is then:

y=\frac{\gamma}{\sqrt{Var[x]+\varepsilon}}\cdot x+(\beta-\frac{\gamma E[x]}{\sqrt{Var[x]+\varepsilon}})
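
In code, the test-time operation is just a fixed per-feature affine map built from the running statistics. A minimal sketch for an input of shape [m, d] with per-feature statistics of shape [d] (the running_mean / running_var names are illustrative placeholders for the E[x] and Var[x] accumulated during training):

import torch

# test-time BN as a fixed affine transform y = a*x + b, built once from
# the statistics accumulated during training (names are illustrative)
def bn_inference(x, running_mean, running_var, gamma, beta, eps=1e-5):
    a = gamma / torch.sqrt(running_var + eps)
    b = beta - a * running_mean
    return a * x + b  # no batch statistics are needed at test time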


By normalizing the input with its mean and variance, batch normalization keeps the statistical distribution of each layer's input consistent, which reduces ICS and accelerates neural network training. In summary, the main benefits of BN are as follows:

(1) The BN transformation is differentiable. It weakens the ICS of the input distribution and keeps the mean and variance of each layer's input stable, while the final linear transformation allows BN to represent the identity mapping and thus recover the network's original transformation if needed.

(2) It makes the gradients less sensitive to the scale of the parameters and their initial values.

(3) It reduces the chance that the values fed into the activation function fall into the saturation region, so easily saturated functions such as sigmoid can still be used.


The batch normalization algorithm can be used in both feed-forward and convolutional neural networks. When applying it to convolutional networks, one should be careful about what to normalize and where to insert it.

Generally, the input is x\in R^{m\times d\times h \times w}, where m is the current batch size, d is the number of input feature maps (channels), and h and w are the height and width of the feature map in an image classification task.

The weight is W \in R^{d\times n\times F_h \times F_w}, where d and n are the numbers of feature maps connected by the weight and F_h, F_w are the filter sizes. The pre-activation output of the convolution, conv(Wx)+b, tends to have a more symmetric, non-sparse distribution, i.e. one closer to a Gaussian, so the batch normalization algorithm is applied to conv(Wx)+b.

When batch normalization is used, the bias b is cancelled by the mean subtraction inside the BN layer, and BN's own \beta parameter already acts as a bias term, so b can be omitted. The original layer is therefore rewritten in the following form, where g denotes the activation function.

z=g(BN(conv(Wx)))
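
With PyTorch's built-in layers this pattern corresponds to dropping the convolution bias and letting BN's \beta act as the offset; a minimal sketch (the channel counts and kernel size are arbitrary examples):

import torch
from torch import nn

# conv -> BN -> activation, with the conv bias omitted because BN's beta
# already provides a per-channel offset (layer sizes are arbitrary examples)
block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
x = torch.randn(8, 3, 32, 32)   # [b, c, h, w]
print(block(x).shape)           # torch.Size([8, 16, 32, 32])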


2. Code implementation

Based on the theoretical derivation above, a Batch Normalization layer can be implemented as follows. An input tensor of shape [b,c,w,h]=[2,3,2,2] is constructed to test the BN layer and verify the standardization.

import torch
from torch import nn

class BN(nn.Module):
    # Initialization
    def __init__(self, channels:int,
                 eps:float=1e-5, momentum:float=0.1,
                 affine:bool=True, track_running_stats:bool=True):
        super(BN, self).__init__()

        self.channels = channels  # number of input channels (features)
        self.eps = eps            # small constant to avoid division by zero
        self.momentum = momentum  # factor for the exponential moving average
        self.affine = affine      # whether to scale and shift the normalized values
        self.track_running_stats = track_running_stats  # whether to keep running estimates of mean and variance

        if self.affine:  # one trainable scale and shift parameter per channel
            self.scale = nn.Parameter(torch.ones(channels))  # [c]
            self.shift = nn.Parameter(torch.zeros(channels))  # [c]

        # buffers are saved with the model but not updated by the optimizer
        # (optimizer.step() leaves them unchanged; they are only updated manually)
        if self.track_running_stats:  # exponential moving averages of mean and variance
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))

    # Forward pass
    def forward(self, x: torch.Tensor):
        x_shape = x.shape  # input shape [b,c,w,h]
        batch_size = x_shape[0]  # number of samples in this step
        # raise an error if the input depth differs from the configured channel count
        assert self.channels == x.shape[1]
        # [b,c,w,h] --> [b,c,w*h]
        x = x.view(batch_size, self.channels, -1)

        # in training mode, or when running stats are not tracked, use batch statistics
        if self.training or not self.track_running_stats:
            # [b,c,w*h] --> [c]
            mean = x.mean(dim=[0,2])  # per-channel mean over the batch and spatial dims (axes 0 and 2)
            mean_x2 = (x**2).mean(dim=[0,2])  # per-channel mean of the squares
            var = mean_x2 - mean**2  # per-channel (biased) variance

            # update the exponential moving averages
            if self.training and self.track_running_stats:
                # detach so the running stats do not keep the autograd graph alive
                self.exp_mean = (1-self.momentum)*self.exp_mean + self.momentum*mean.detach()
                self.exp_var = (1-self.momentum)*self.exp_var + self.momentum*var.detach()

        else:  # in evaluation mode, use the running statistics instead
            mean = self.exp_mean
            var = self.exp_var

        # standardize --> [b,c,w*h]
        x_norm = (x-mean.view(1,-1,1)) / torch.sqrt(var+self.eps).view(1,-1,1)
        # learnable scale and shift parameters [b,c,w*h]
        if self.affine:
            x_norm = self.scale.view(1,-1,1) * x_norm + self.shift.view(1,-1,1)

        # [b,c,w*h] --> [b,c,w,h]
        return x_norm.view(x_shape)

# ------------------------------ #
# Test
# ------------------------------ #

if __name__ == '__main__':

    x = torch.linspace(0, 23, 24, dtype=torch.float32)
    x = x.reshape([2,3,2,2])  # [b,c,w,h]
    bn = BN(channels=3)  # instantiate the layer
    # forward pass
    x = bn(x)
    print(x)
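
    # Optional sanity check (a sketch, assuming the built-in layer's defaults):
    # in training mode this module should match PyTorch's nn.BatchNorm2d,
    # since both normalize with the biased batch variance and start from gamma=1, beta=0
    x2 = torch.linspace(0, 23, 24, dtype=torch.float32).reshape([2,3,2,2])
    print(torch.allclose(BN(channels=3)(x2), nn.BatchNorm2d(3)(x2), atol=1e-5))  # expected: True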


Source: blog.csdn.net/dgvv4/article/details/130567501