[Normalization methods] (3) Group Normalization: principle analysis and code reproduction, with PyTorch code

Today I'd like to share Group Normalization, a normalization method commonly used in deep learning. I'll introduce its mathematical principle and reproduce it in PyTorch.

Group Normalization paper: https://arxiv.org/pdf/1803.08494.pdf


1. Principle introduction

In large-scale computer vision tasks such as object detection and video classification, memory limits force a small batch size, and a small batch size inevitably hurts the performance of batch normalization.

Group Normalization (GN) was proposed to address a weakness of batch normalization (BN): its strong dependence on the batch size. Because BN computes its statistics over the batch, the estimated mean and variance become less accurate and less stable as the batch shrinks, and with small batches the error rate rises noticeably.

Group normalization sits between layer normalization (LN) and instance normalization (IN). Consider an image input of shape [N,C,H,W], where N is the batch size, C is the number of input channels, and H and W are the height and width of the input image.

GN first divides the C input channels into G groups and then normalizes each group separately: the input is reshaped from [N,C,H,W] to [N,G,\frac{C}{G},H,W], and the mean and variance are computed over the [\frac{C}{G},H,W] dimensions of each group.
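The reshape-and-normalize step can be sketched directly with tensor operations. This is a minimal sketch that leaves out the learnable affine parameters; the helper name `group_norm_reshape` is hypothetical, and its output is cross-checked against PyTorch's built-in `torch.nn.functional.group_norm`.

```python
import torch
import torch.nn.functional as F

def group_norm_reshape(x: torch.Tensor, G: int, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: GN without affine parameters, via an explicit 5D reshape."""
    N, C, H, W = x.shape
    assert C % G == 0, "C must be divisible by G"
    # [N, C, H, W] -> [N, G, C/G, H, W]
    xg = x.reshape(N, G, C // G, H, W)
    # One mean and one (biased) variance per (sample, group), taken over [C/G, H, W]
    mean = xg.mean(dim=(2, 3, 4), keepdim=True)
    var = xg.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    xg = (xg - mean) / torch.sqrt(var + eps)
    # Back to [N, C, H, W]
    return xg.reshape(N, C, H, W)

torch.manual_seed(0)
x = torch.randn(2, 6, 4, 4)
out = group_norm_reshape(x, G=3)
ref = F.group_norm(x, num_groups=3)  # built-in GN, no weight/bias
print(torch.allclose(out, ref, atol=1e-5))  # True
```

The 5D view is only a change of indexing: merging the last three axes into one gives the same groups, which is the trick the implementation in section 2 relies on.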

In fact, when G=1 all input channels form a single group and GN computes exactly the same thing as LN; when G=C each channel forms its own group and GN computes exactly the same thing as IN.
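A quick numerical check of these two special cases with PyTorch's functional API. It assumes no learnable affine parameters on any layer: LayerNorm's affine weights are per element while GN's are per channel, so the layers only coincide in the plain normalization step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 6, 4, 4
x = torch.randn(N, C, H, W)

# G = 1: a single group holding all channels -> statistics over [C, H, W], same as LN
gn_g1 = F.group_norm(x, num_groups=1)
ln = F.layer_norm(x, normalized_shape=(C, H, W))

# G = C: one channel per group -> statistics over [H, W], same as IN
gn_gc = F.group_norm(x, num_groups=C)
inorm = F.instance_norm(x)

print(torch.allclose(gn_g1, ln, atol=1e-5))     # True
print(torch.allclose(gn_gc, inorm, atol=1e-5))  # True
```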

The figure above illustrates batch normalization (BN), layer normalization (LN), instance normalization (IN) and group normalization (GN). Each cube represents the input feature map, and the blue region marks the elements over which each method computes a single mean and variance.

Here C is the number of channels and N is the batch size; the height and width are merged into a single axis of size H*W, so the 4D input can be drawn as a 3D cube. The figure shows that only BN's statistics involve the batch dimension N. LN, IN and GN are all computed within a single sample, and LN and IN are the two extreme cases of GN (G=1 and G=C).
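The region each method averages over maps directly to the tensor dimensions being reduced. The sketch below, assuming an arbitrary input of shape [4, 6, 5, 5] and G=3, prints the shape of the resulting mean tensor for each method; only BN's statistics have no batch axis, i.e. they are shared by every sample in the batch.

```python
import torch

torch.manual_seed(0)
N, C, H, W, G = 4, 6, 5, 5, 3
x = torch.randn(N, C, H, W)

bn_mean = x.mean(dim=(0, 2, 3))              # over N, H, W   -> one value per channel
ln_mean = x.mean(dim=(1, 2, 3))              # over C, H, W   -> one value per sample
in_mean = x.mean(dim=(2, 3))                 # over H, W      -> one value per (sample, channel)
gn_mean = x.reshape(N, G, -1).mean(dim=-1)   # over C/G, H, W -> one value per (sample, group)

print(bn_mean.shape)  # torch.Size([6])
print(ln_mean.shape)  # torch.Size([4])
print(in_mean.shape)  # torch.Size([4, 6])
print(gn_mean.shape)  # torch.Size([4, 3])
```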

In general, the normalization statistics are computed as follows:

\mu_i=\frac{1}{m}\sum_{k\in S_i}x_k

\sigma_i=\sqrt{\frac{1}{m}\sum_{k\in S_i}\left(x_k-\mu_i\right)^2+\epsilon}
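The two formulas above only give the statistics. The normalization step itself, as written in the GN paper, subtracts the mean of the element's own set S_i, divides by its standard deviation, and then applies a learnable per-channel scale \gamma and shift \beta (the `scale` and `shift` parameters in the code of section 2):

\hat{x}_i=\frac{1}{\sigma_i}\left(x_i-\mu_i\right)

y_i=\gamma\hat{x}_i+\beta

As a small worked example, if S_i contains the four values {0, 1, 2, 3}, then m=4, \mu_i=1.5 and \sigma_i=\sqrt{\frac{2.25+0.25+0.25+2.25}{4}+\epsilon}\approx 1.118 (ignoring \epsilon), so the element x_i=0 is normalized to \hat{x}_i\approx -1.342.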

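Here i=(i_N,i_C,i_H,i_W) indexes an element of the input, S_i is the set of elements over which the mean and variance are computed, m is the size of this set, and \epsilon is a small constant for numerical stability. In BN: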

S_i=\left\{k|k_C=i_C\right\}

In LN:

S_i=\left\{k|k_N=i_N\right\}

In GN:

S_i=\left\{k \mid k_N=i_N,\ \left\lfloor\frac{k_C}{C/G}\right\rfloor=\left\lfloor\frac{i_C}{C/G}\right\rfloor\right\}
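To make the set notation concrete, the sketch below builds S_i for GN literally from the membership condition (same sample, and the same value of floor(k_C/(C/G))) and checks that the resulting \mu_i and \sigma_i match the statistics of the reshaped [N, G, (C/G)*H*W] view. The function `gn_stats_from_set` is hypothetical and deliberately inefficient; it is only meant to mirror the formula.

```python
import torch

def gn_stats_from_set(x: torch.Tensor, G: int, i: tuple, eps: float = 1e-5):
    """Hypothetical helper: mu_i and sigma_i for element i = (n, c, h, w),
    computed by building S_i from the GN membership condition."""
    N, C, H, W = x.shape
    n_i, c_i = i[0], i[1]
    group_size = C // G  # C/G channels per group
    # k is in S_i iff k_N == i_N and floor(k_C / (C/G)) == floor(i_C / (C/G))
    same_group = (torch.arange(C) // group_size) == (c_i // group_size)  # [C] bool mask
    values = x[n_i][same_group]  # [C/G, H, W]: every element of S_i
    m = values.numel()           # m = |S_i|
    mu = values.sum() / m
    sigma = torch.sqrt(((values - mu) ** 2).sum() / m + eps)
    return mu, sigma

torch.manual_seed(0)
x = torch.randn(2, 6, 3, 3)
G = 3
mu, sigma = gn_stats_from_set(x, G, i=(1, 4, 0, 2))

# Same statistics from the reshaped view; channel 4 lies in group 4 // (6 // 3) = 2
xg = x.reshape(2, G, -1)
mu_ref = xg.mean(dim=-1)[1, 2]
sigma_ref = torch.sqrt(xg.var(dim=-1, unbiased=False) + 1e-5)[1, 2]
print(torch.allclose(mu, mu_ref, atol=1e-5), torch.allclose(sigma, sigma_ref, atol=1e-5))  # True True
```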

Pros: Does not depend on batch size.

Cons: Performance is not as good as BN when the batch size is large.
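The batch-size independence can be checked by normalizing the same sample inside batches of different sizes. This sketch uses PyTorch's built-in `nn.GroupNorm` and `nn.BatchNorm2d` (the latter in its default training mode, so it uses batch statistics); the shapes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 6, 4, 4)  # a batch of 8 samples

gn = nn.GroupNorm(num_groups=3, num_channels=6)
bn = nn.BatchNorm2d(6)  # training mode by default -> normalizes with batch statistics

# Output for sample 0 inside a batch of 2 versus inside a batch of 8
gn_small, gn_large = gn(x[:2])[0], gn(x)[0]
bn_small, bn_large = bn(x[:2])[0], bn(x)[0]

print(torch.allclose(gn_small, gn_large))  # True: GN only looks at sample 0 itself
print(torch.allclose(bn_small, bn_large))  # False: BN statistics depend on the other samples
```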


2. Code implementation

import torch 
from torch import nn

class GN(nn.Module):
    # Initialization
    def __init__(self, groups: int, channels: int,
                 eps: float = 1e-5, affine: bool = True):
        super().__init__()
        # The number of channels must be divisible by the number of groups
        assert channels % groups == 0, 'channels should be evenly divisible by groups'
        self.groups = groups      # number of groups the channels are split into
        self.channels = channels  # number of channels
        self.eps = eps            # small constant that prevents division by zero
        self.affine = affine      # whether to use learnable per-channel affine parameters
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))   # scale factor (gamma)
            self.shift = nn.Parameter(torch.zeros(channels))  # bias (beta)

    # Forward pass
    def forward(self, x: torch.Tensor):
        x_shape = x.shape        # input feature shape [b,c,h,w]
        batch_size = x_shape[0]  # number of samples
        # The configured channel count must match the channel count of the input
        assert self.channels == x.shape[1]
        # [b,c,h,w] --> [b,g,(c/g)*h*w]; reshape also handles non-contiguous inputs
        x = x.reshape(batch_size, self.groups, -1)
        # Normalize over the last dimension, i.e. over all elements of one group
        mean = x.mean(dim=-1, keepdim=True)                # [b,g,1]
        # Biased variance; numerically safer than E[x^2] - E[x]^2
        var = x.var(dim=-1, unbiased=False, keepdim=True)  # [b,g,1]
        x_norm = (x - mean) / torch.sqrt(var + self.eps)   # [b,g,(c/g)*h*w]
        # Per-channel affine transformation
        if self.affine:
            x_norm = x_norm.reshape(batch_size, self.channels, -1)  # [b,c,h*w]
            # [1,c,1] * [b,c,h*w] + [1,c,1]
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1)
        # Back to [b,c,h,w]
        return x_norm.reshape(x_shape)

# ---------------------------------- #
# Verification
# ---------------------------------- #

if __name__ == '__main__':
    # Build an input tensor holding the values 0..47
    x = torch.linspace(0, 47, 48, dtype=torch.float32)
    x = x.reshape([2, 6, 2, 2])  # [b,c,h,w]
    # Instantiate: 6 channels split into 3 groups
    gn = GN(groups=3, channels=6)
    # Forward pass
    x = gn(x)
    print(x.shape)  # torch.Size([2, 6, 2, 2])
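In practice PyTorch already ships this layer as `nn.GroupNorm(num_groups, num_channels, eps=1e-5, affine=True)`, whose `weight` and `bias` play the role of `scale` and `shift` above (initialized to ones and zeros). A minimal sketch of dropping it into a hypothetical conv block, assuming 32 groups, the default used in the GN paper:

```python
import torch
import torch.nn as nn

# Hypothetical conv block using PyTorch's built-in GroupNorm
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),  # bias is redundant before a norm layer
    nn.GroupNorm(num_groups=32, num_channels=64),             # 64 channels -> 32 groups of 2
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 3, 32, 32)  # GN also works with a batch size of 1
y = block(x)
print(y.shape)                 # torch.Size([1, 64, 32, 32])
print(block[1].weight.shape)   # torch.Size([64]): one scale per channel
```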


Origin blog.csdn.net/dgvv4/article/details/130579409