PyTorch学习笔记(28) BN LN IN GN

Why Normalization

Internal Covariate Shift(ICS):数据尺度/分布异常,导致训练困难

H 11 = i = 0 n X i W 1 i D ( H 11 ) = i = 0 n D ( X i ) D ( W 1 i ) = n ( 1 1 ) = n \begin{aligned} \mathrm{H}_{11}=& \sum_{i=0}^{n} X_{i} * W_{1 i} \\ \mathrm{D}\left(\mathrm{H}_{11}\right) &=\sum_{i=0}^{n} D\left(X_{i}\right) * D\left(W_{1 i}\right) \\ &=n *(1 * 1) \\ &=n \end{aligned}
std ( H 11 ) = D ( H 11 ) = n D ( H 1 ) = n D ( X ) D ( W ) = 1 \begin{array}{l} \operatorname{std}\left(\mathrm{H}_{11}\right)=\sqrt{\mathbf{D}\left(\mathrm{H}_{11}\right)}=\sqrt{n} \\ \mathbf{D}\left(\mathrm{H}_{1}\right)=\boldsymbol{n} * \boldsymbol{D}(\boldsymbol{X}) * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1} \end{array}
D ( W ) = 1 n std ( W ) = 1 n D(W)=\frac{1}{n} \Rightarrow \operatorname{std}(W)=\sqrt{\frac{1}{n}}

常见的Normalization

1.Batch Normalization(BN)
2.Layer Normalization(LN)
3.Instance Normalization(IN)
4.Group Normalization(GN)

相同点

x ^ i x i μ B σ B 2 + ϵ \widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}}
$$
y_{i} \leftarrow \gamma \widehat{x}{i}+\beta \equiv \mathrm{N}{\gamma, \beta}\left(x_{i}\right)

不同点

均值和方差求取方式

1.Layer Normalization

起因:BN不适合用于变长的网络,如RNN
思路:逐层计算均值和方差

注意事项:

1.不再有running_mean 和 running_var
2.gamma 和 beta 为逐元素、逐特征的

nn.LayerNorm

主要参数:
normalized_shape:该层特征形状
eps:分母修正项
elementwise_affine:是否需要affine transform

2.Instance Normalization

起因:BN在图像生成(Image Ganeration)中不适用
思路:==逐Instance(channel)==计算均值和方差
计算方式 逐通道的

nn.InstanceNorm

主要参数:
num_features:一个样本特征数量(最重要)
eps:分母修正项
momentum:指数加权平均估计当前mean/var
affine:是否需要affine transform
track_running_stats:是训练状态,还是测试状态

3.Group Normalization

起因:小batch样本中,BN估计的值不准
思路:数据不够,通道来凑

扫描二维码关注公众号,回复: 9571489 查看本文章

注意事项

1.不再有running_mean和running_var
2.gamma 和beta 为逐通道(channel)的

应用场景 大模型(小batch size)任务

nn.GroupNorm

主要参数
num_groups 分组数 通产设为2的n次方
num_channels 通道数(特征数)
eps 分母修正项
affine 是否需要affine transform

小结:

BN LN IN GN 都是为了克服Internal Covariate shift(ICS)

加减乘除

减均值 除标准差 乘γ 加β



# -*- coding: utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


set_seed(1)  # 设置随机种子

# ======================================== nn.layer norm
# flag = 1
flag = 0
if flag:
    batch_size = 2
    num_features = 3

    features_shape = (2,2)
    # features_shape = (3, 4)

    feature_map = torch.ones(features_shape)  # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    # feature_maps_bs shape is [8, 6, 3, 4],  B * C * H * W
    ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=True)
    # ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=False)
    # ln = nn.LayerNorm([6, 3, 4])
    # ln = nn.LayerNorm([6, 3])

    output = ln(feature_maps_bs)

    print("Layer Normalization")
    print(ln.weight.shape)
    print(feature_maps_bs[0, ...])
    print(output[0, ...])

# ======================================== nn.instance norm 2d
# flag = 1
flag = 0
if flag:

    batch_size = 3
    num_features = 3
    momentum = 0.3

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)    # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    print("Instance Normalization")
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    instance_n = nn.InstanceNorm2d(num_features=num_features, momentum=momentum)

    for i in range(1):
        outputs = instance_n(feature_maps_bs)

        print(outputs)
        # print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        # print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
        # print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        # print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))


# ======================================== nn.grop norm
flag = 1
# flag = 0
if flag:

    batch_size = 2
    num_features = 4
    # 设置分组数时一定是能被整除的 通常设置为2的N次幂
    num_groups = 2   # 3 Expected number of channels in input to be divisible by num_groups

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)    # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps * (i + 1) for i in range(batch_size)], dim=0)  # 4D
    # 分组数 有几个特征图
    gn = nn.GroupNorm(num_groups, num_features)
    outputs = gn(feature_maps_bs)

    print("Group Normalization")
    print(gn.weight.shape)
    print(outputs[0])

发布了32 篇原创文章 · 获赞 0 · 访问量 454

猜你喜欢

转载自blog.csdn.net/qq_33357094/article/details/104649044