Understand BatchNorm in CNN in one article

1. Introduction

This article focuses on the definition and key properties of BatchNorm, and then walks through a concrete implementation and its practical applications. I hope it helps you understand the layer better.

Well, without further ado, let's get started!

2. What is BatchNorm?

BatchNorm is a network layer proposed by Ioffe and Szegedy in 2015. This layer has the following characteristics:

  • Ease of training: Because the distribution of each layer's inputs varies much less with this layer in place, we can use higher learning rates. The direction of the updates is less erratic during training, so we move faster towards convergence of the loss.

  • Boost regularization: Although the network sees the same training samples every epoch, the normalization statistics are computed per mini-batch, so the normalized value of a given sample changes slightly depending on which other samples it is batched with (a short numerical example follows this list).

  • Improved accuracy: Probably due to the combination of the previous two points, the paper mentions that they achieved better accuracy than the state-of-the-art results at the time.
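
To make the regularization point concrete, here is a tiny sketch (the values are made up for illustration) showing that the same sample gets a different normalized value depending on which mini-batch it lands in:

import torch

def normalize(batch):
    # Normalize with the batch's own mean and (biased) standard deviation.
    return (batch - batch.mean()) / batch.std(unbiased=False)

batch_a = torch.tensor([1.0, 2.0])         # the sample 2.0 in one mini-batch
batch_b = torch.tensor([2.0, 3.0, 4.0])    # the same sample 2.0 in another

print(normalize(batch_a)[1])   # tensor(1.0000)
print(normalize(batch_b)[0])   # tensor(-1.2247)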

3. How does BatchNorm work?

What BatchNorm does is ensure that the input it receives has mean 0 and standard deviation 1, computed per feature over the mini-batch, followed by a learnable scale and shift.
The algorithm introduced in the paper is as follows:
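For a mini-batch x_1, ..., x_m, the transform has four steps; γ and β are the learnable scale and shift, and ε is a small constant added for numerical stability:

μ_B  = (1/m) · Σ_i x_i                      (mini-batch mean)
σ_B² = (1/m) · Σ_i (x_i − μ_B)²             (mini-batch variance)
x̂_i  = (x_i − μ_B) / √(σ_B² + ε)            (normalize)
y_i  = γ · x̂_i + β                          (scale and shift)
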
The following is my own implementation in PyTorch:

import torch
from torch import nn
from torch.nn import Parameter

class BatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        # Learnable scale (gamma) and shift (beta), one value per feature.
        self.gamma = Parameter(torch.Tensor(num_features))
        self.beta = Parameter(torch.Tensor(num_features))
        # Running estimates of mean and variance, used at inference time;
        # buffers are saved with the model but not updated by the optimizer.
        self.register_buffer("moving_avg", torch.zeros(num_features))
        self.register_buffer("moving_var", torch.ones(num_features))
        self.register_buffer("eps", torch.tensor(eps))
        self.register_buffer("momentum", torch.tensor(momentum))
        self._reset()
    
    def _reset(self):
        self.gamma.data.fill_(1)
        self.beta.data.fill_(0)
    
    def forward(self, x):
        # Expects input of shape (batch, num_features); for conv feature maps of
        # shape (N, C, H, W) the statistics would be taken over dims (0, 2, 3).
        if self.training:
            # Use the statistics of the current mini-batch and update the running
            # estimates. Note: with this convention `momentum` is the weight given
            # to the old running value (PyTorch's built-in BatchNorm uses the
            # opposite convention).
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance, as in the paper
            with torch.no_grad():
                self.moving_avg = self.moving_avg * self.momentum + mean * (1 - self.momentum)
                self.moving_var = self.moving_var * self.momentum + var * (1 - self.momentum)
        else:
            # At inference time, fall back to the running estimates.
            mean = self.moving_avg
            var = self.moving_var

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return x_norm * self.gamma + self.beta

A few additional notes on this implementation:

  • BatchNorm behaves differently during training and inference. During training we keep exponential moving averages of the mean and variance, which are then used at inference time. The reason is that, while processing many batches during training, we accumulate increasingly good estimates of the input's mean and variance; by the law of large numbers these estimates are more reliable than statistics computed from a single, possibly much smaller, batch at inference time (a short demo follows this list).
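
As a quick sanity check of the two modes, here is a minimal usage sketch of the BatchNorm class above (the shapes and values are arbitrary): in training mode each feature of the output is normalized with the statistics of the current batch, while in eval mode the stored running estimates are used instead.

batch = torch.randn(32, 8) * 5 + 3   # 32 samples, 8 features, non-zero mean and std
bn = BatchNorm(num_features=8)

bn.train()                 # training mode: normalize with the batch statistics
out = bn(batch)
print(out.mean(dim=0))     # close to 0 for every feature
print(out.std(dim=0))      # close to 1 for every feature

bn.eval()                  # inference mode: normalize with the running estimates
out_eval = bn(batch)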

4. When to use BatchNorm?

It almost always seems to help, so there is little reason not to use it. It usually goes between fully connected or convolutional layers and the activation function. Some have argued that it works better after the activation instead, but I could not find papers supporting that placement, so the safest bet is to place it before the activation function, as most people do. A typical placement is sketched below.
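
For example, a small convolutional block with the usual Conv → BatchNorm → ReLU ordering might look like this (the layer sizes are arbitrary, and PyTorch's built-in nn.BatchNorm2d is used instead of the hand-written class):

import torch
from torch import nn

# Conv -> BatchNorm -> ReLU: the normalization sits between the convolution
# and the activation function. The conv bias is dropped because BatchNorm's
# beta parameter makes it redundant.
block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)

x = torch.randn(4, 3, 32, 32)   # a batch of 4 RGB images
y = block(x)                    # shape: (4, 16, 32, 32)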

5. Summary of practical tips

Some tips for using BatchNorm in practice are summarized below:

  • A trained network contains running means and variances computed on the dataset it was trained on, and this can be a problem during transfer learning. We usually freeze most of the layers, and if we are not careful the BatchNorm layers get frozen too, which means the running statistics still belong to the original dataset rather than the new one. Unfreezing the BatchNorm layers is a good idea: it lets the network recompute the running mean and variance on its own dataset (a minimal sketch follows this list).
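
Here is a minimal sketch of that idea, assuming a torchvision ResNet-18 backbone (any model containing nn.BatchNorm2d layers works the same way): every parameter is frozen except those of the BatchNorm layers, which are also kept in training mode so that their running statistics are re-estimated on the new dataset.

from torch import nn
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze everything first.
for p in model.parameters():
    p.requires_grad = False

# Un-freeze the BatchNorm layers so that gamma/beta keep learning and the
# running mean/variance are re-computed on the new dataset.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        for p in m.parameters():
            p.requires_grad = True
        m.train()   # keep updating the batch statistics while fine-tuning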


Origin: blog.csdn.net/sgzqc/article/details/127952552