Why Normalization
Internal Covariate Shift (ICS): the scale and distribution of the data become abnormal, which makes training difficult
Common Normalization methods
1.Batch Normalization(BN)
2.Layer Normalization(LN)
3.Instance Normalization(IN)
4.Group Normalization(GN)
What they have in common
$$
y_{i} \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{N}_{\gamma, \beta}\left(x_{i}\right)
$$
How they differ
The way the mean and variance are computed
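The difference in how the statistics are taken can be sketched as a single helper that only changes the reduction axes. This is a minimal illustration, not the library's implementation: the helper `normalize`, the tensor sizes, and the group count `G` are assumptions made for the example, and the affine step is switched off so that only the mean/variance computation is compared.

```python
import torch
import torch.nn as nn


def normalize(x, dims, eps=1e-5):
    """Subtract the mean and divide by the std, both taken over `dims`."""
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)  # biased variance
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
N, C, H, W, G = 4, 6, 5, 5, 3  # illustrative sizes; G groups of C // G channels
x = torch.randn(N, C, H, W)

bn_manual = normalize(x, dims=(0, 2, 3))  # BN: per channel, across the batch
ln_manual = normalize(x, dims=(1, 2, 3))  # LN: per sample, across C, H, W
in_manual = normalize(x, dims=(2, 3))     # IN: per sample and per channel
gn_manual = normalize(                    # GN: per sample and per channel group
    x.reshape(N, G, C // G, H, W), dims=(2, 3, 4)
).reshape(N, C, H, W)

# Reference modules with the affine step disabled, BN in training mode.
references = {
    "BN": nn.BatchNorm2d(C, affine=False).train()(x),
    "LN": nn.LayerNorm([C, H, W], elementwise_affine=False)(x),
    "IN": nn.InstanceNorm2d(C)(x),
    "GN": nn.GroupNorm(G, C, affine=False)(x),
}
manual = {"BN": bn_manual, "LN": ln_manual, "IN": in_manual, "GN": gn_manual}

# Record whether each manual result agrees with the corresponding module.
matches = {k: torch.allclose(manual[k], references[k], atol=1e-5) for k in manual}
print(matches)
```

Only the `dims` argument changes between the four calls, which is the point of the comparison.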
1.Layer Normalization
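Motivation: BN is ill-suited to networks with variable-length inputs, such as RNNs
Idea: compute the mean and variance layer by layer, i.e. over the features of a layer for each sample
Notes:
1.No running_mean or running_var is kept
2.gamma and beta are element-wise, one per feature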
nn.LayerNorm
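Main parameters:
normalized_shape: shape of the features of this layer
eps: term added to the denominator for numerical stability
elementwise_affine: whether to apply the affine transform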
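A small sketch of these parameters in use, assuming a hidden size of 8 and two made-up sequence lengths (these sizes are illustrative only). It checks the two notes above, and shows that one layer can serve inputs of different lengths because the statistics are taken over the last dimension only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 8  # illustrative feature size
ln = nn.LayerNorm(normalized_shape=hidden, eps=1e-5, elementwise_affine=True)

# gamma (weight) and beta (bias) have exactly the normalized shape.
print(ln.weight.shape, ln.bias.shape)
# No running statistics are stored on the module.
print(hasattr(ln, "running_mean"))

# The same layer handles sequences of different lengths: (batch, seq_len, hidden).
short_seq = torch.randn(2, 3, hidden)
long_seq = torch.randn(2, 11, hidden)
out_short = ln(short_seq)
out_long = ln(long_seq)

# Each feature vector is normalized on its own, so its mean is close to 0.
print(out_long.mean(dim=-1).abs().max())

# With elementwise_affine=False there are no learnable gamma/beta at all.
ln_plain = nn.LayerNorm(hidden, elementwise_affine=False)
print(ln_plain.weight is None)
```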
2.Instance Normalization
Motivation: BN is not suitable for image generation
Idea: compute the mean and variance ==per instance (channel)==
Computation: channel by channel, separately for each sample
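nn.InstanceNorm (nn.InstanceNorm1d / nn.InstanceNorm2d / nn.InstanceNorm3d)
Main parameters:
num_features: number of features (channels) in one sample (the most important one)
eps: term added to the denominator for numerical stability
momentum: momentum of the exponentially weighted average used to estimate the running mean/var
affine: whether to apply the affine transform
track_running_stats: whether to track running mean/var; when enabled, the running estimates are used in evaluation mode instead of the per-instance statistics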
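The following sketch exercises these parameters on `nn.InstanceNorm2d`. The tensor sizes are assumptions chosen for illustration. It relies on the PyTorch defaults `affine=False` and `track_running_stats=False`, under which the module holds neither gamma/beta nor running estimates.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(3, 4, 5, 5)  # illustrative (batch, channels, height, width)

# Defaults: affine=False and track_running_stats=False.
inorm = nn.InstanceNorm2d(num_features=4)
out = inorm(x)
print(inorm.weight is None, inorm.running_mean is None)

# Every (sample, channel) slice is normalized separately: mean close to 0.
print(out.mean(dim=(2, 3)).abs().max())

# Opting in adds per-channel gamma/beta and per-channel running estimates.
inorm_full = nn.InstanceNorm2d(
    4, momentum=0.3, affine=True, track_running_stats=True
)
inorm_full(x)  # a forward pass in training mode updates the running estimates
print(inorm_full.weight.shape, inorm_full.running_mean.shape)
```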
3.Group Normalization
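Motivation: with a small batch, the statistics estimated by BN are inaccurate
Idea: when there is not enough data, make up for it with channels
Notes:
1.No running_mean or running_var is kept
2.gamma and beta are per channel
Use case: large-model tasks, where only a small batch size is feasible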
nn.GroupNorm
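Main parameters:
num_groups: number of groups, usually set to a power of 2; num_channels must be divisible by it
num_channels: number of channels (features)
eps: term added to the denominator for numerical stability
affine: whether to apply the affine transform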
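A short sketch of `nn.GroupNorm` under assumed sizes (batch of 8, 4 channels, 2 groups). It checks the notes above, shows that the output for a sample does not depend on the batch it arrives in, and shows that a group count that does not divide the channel count is rejected. Depending on the PyTorch version the rejection may happen at construction or at the forward pass, so the example catches both exception types.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5)  # illustrative (batch, channels, height, width)

gn = nn.GroupNorm(num_groups=2, num_channels=4)
print(gn.weight.shape)              # gamma is per channel
print(hasattr(gn, "running_mean"))  # no running statistics are stored

# Statistics are taken per sample, so the batch size does not matter.
full_batch = gn(x)
single = gn(x[:1])
batch_independent = torch.allclose(full_batch[:1], single, atol=1e-5)
print(batch_independent)

# num_channels must be divisible by num_groups.
try:
    bad = nn.GroupNorm(num_groups=3, num_channels=4)
    bad(x)
    divisibility_error = False
except (ValueError, RuntimeError) as err:
    divisibility_error = True
    print("rejected:", err)
```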
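Summary:
BN, LN, IN and GN all aim to overcome Internal Covariate Shift (ICS)
All four apply the same four arithmetic steps: subtract, divide, multiply, add
Subtract the mean, divide by the standard deviation, multiply by γ, add β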
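Written out, the four steps look as follows. The symbols $\mu$ and $\sigma^{2}$ are notation assumed here for the mean and variance over whichever axes the method uses, and $\epsilon$ stands for the `eps` parameter:

$$
\widehat{x}_{i}=\frac{x_{i}-\mu}{\sqrt{\sigma^{2}+\epsilon}}, \qquad y_{i}=\gamma \widehat{x}_{i}+\beta
$$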
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1)  # set the random seed
# ======================================== nn.layer norm
# flag = 1
flag = 0
if flag:
    batch_size = 2
    num_features = 3

    features_shape = (2, 2)
    # features_shape = (3, 4)

    feature_map = torch.ones(features_shape)  # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    # feature_maps_bs shape is [2, 3, 2, 2], i.e. B * C * H * W
    ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=True)
    # ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=False)  # ln.weight is then None
    # ln = nn.LayerNorm([3, 2, 2])
    # ln = nn.LayerNorm([3, 2])  # fails: normalized_shape must match the trailing dims of the input

    output = ln(feature_maps_bs)

    print("Layer Normalization")
    print(ln.weight.shape)
    print(feature_maps_bs[0, ...])
    print(output[0, ...])
# ======================================== nn.instance norm 2d
# flag = 1
flag = 0
if flag:
    batch_size = 3
    num_features = 3
    momentum = 0.3

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)  # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4D

    print("Instance Normalization")
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    instance_n = nn.InstanceNorm2d(num_features=num_features, momentum=momentum)

    for i in range(1):
        outputs = instance_n(feature_maps_bs)

        print(outputs)
        # The prints below need affine=True and track_running_stats=True;
        # with the defaults used above these attributes are None.
        # print("\niter:{}, running_mean.shape: {}".format(i, instance_n.running_mean.shape))
        # print("iter:{}, running_var.shape: {}".format(i, instance_n.running_var.shape))
        # print("iter:{}, weight.shape: {}".format(i, instance_n.weight.shape))
        # print("iter:{}, bias.shape: {}".format(i, instance_n.bias.shape))
# ======================================== nn.group norm
flag = 1
# flag = 0
if flag:
    batch_size = 2
    num_features = 4
    # the number of channels must be divisible by num_groups; num_groups is usually a power of 2
    num_groups = 2  # 3 fails: Expected number of channels in input to be divisible by num_groups

    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)  # 2D
    feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0)  # 3D
    feature_maps_bs = torch.stack([feature_maps * (i + 1) for i in range(batch_size)], dim=0)  # 4D

    # arguments: number of groups, number of channels (feature maps)
    gn = nn.GroupNorm(num_groups, num_features)
    outputs = gn(feature_maps_bs)

    print("Group Normalization")
    print(gn.weight.shape)
    print(outputs[0])