1. Batch Normalization
Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[C]//International conference on machine learning. PMLR, 2015: 448-456.
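As shown in the figure above, BN computes the mean and variance over every dimension except the channel dimension and uses them to normalize the input. In other words, features belonging to the same channel are normalized jointly across all samples in the batch. Taking the mean as an example: for channel c, the channel-c feature maps of all N samples are averaged into one H×W mean map, which is then averaged over height and width. This gives a single 1×1 value per channel, so C channels yield C mean values. γ and β are learnable parameters (a per-channel scale and shift), not hyperparameters. BN was proposed to speed up neural network training by reducing internal covariate shift, and with BN we can use larger learning rates.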
In addition, BN has a regularizing effect, which reduces the need for Dropout.
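The per-channel computation described above can be written out as follows for an input x of shape [N, C, H, W]. Here ε is a small constant for numerical stability (1e-5 by default in PyTorch). This is the standard training-time form. At inference time, PyTorch's BN layers substitute running estimates for μ_c and σ_c² by default.

```latex
\begin{aligned}
\mu_c &= \frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{NHW}\sum_{n=1}^{N}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(x_{n,c,h,w}-\mu_c\right)^2,\\
y_{n,c,h,w} &= \gamma_c\,\frac{x_{n,c,h,w}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}} + \beta_c
\end{aligned}
```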
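Concise PyTorch implementation:

```python
import torch
import torch.nn as nn

BN = nn.BatchNorm2d(num_features=2)
# One sample with 2 channels of 3x3 feature maps: shape [2, 3, 3]
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                  [[0, 1, 2], [3, 4, 5], [6, 7, 20]]], dtype=torch.float32)
a = a.unsqueeze(0)        # [1, 2, 3, 3]
a = a.repeat(3, 1, 1, 1)  # [3, 2, 3, 3]: a batch of 3 identical samples
print(BN(a))
```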
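Custom BN. It does not track running statistics, so its output matches nn.BatchNorm2d above only in training mode:

```python
class MyBN(nn.Module):
    def __init__(self, num_features):
        super(MyBN, self).__init__()
        # Learnable per-channel scale (gamma) and shift (beta)
        self.scale = nn.Parameter(torch.ones(size=(num_features,)))
        self.shift = nn.Parameter(torch.zeros(size=(num_features,)))
        self.eps = 1e-5

    def forward(self, x):
        # x: [N, C, H, W]; statistics are taken over N, H and W for each channel
        mean = torch.mean(x, dim=[0, 2, 3], keepdim=True)
        # Biased variance E[x^2] - E[x]^2, which is what BN normalizes with
        var = torch.mean(x**2, dim=[0, 2, 3], keepdim=True) - mean**2
        x = (x - mean) / (torch.sqrt(var + self.eps))
        # Reshape [C] -> [C, 1, 1] so it broadcasts against [N, C, H, W]
        x = x * self.scale.reshape(-1, 1, 1) + self.shift.reshape(-1, 1, 1)
        return x

MBN = MyBN(2)
print(MBN(a))
```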
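The statement that the custom BN matches nn.BatchNorm2d holds in training mode. In eval mode, nn.BatchNorm2d switches to its running_mean / running_var buffers. A minimal check is below. The random input, seed and shapes are arbitrary choices for illustration, and MyBN is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Same per-channel BN as the custom MyBN, repeated to keep this snippet self-contained
class MyBN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))
        self.shift = nn.Parameter(torch.zeros(num_features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.scale.view(-1, 1, 1) + self.shift.view(-1, 1, 1)

torch.manual_seed(0)
x = torch.randn(4, 2, 3, 3)

bn = nn.BatchNorm2d(num_features=2)
my_bn = MyBN(2)

# Training mode: nn.BatchNorm2d normalizes with the current batch statistics
bn.train()
train_match = torch.allclose(bn(x), my_bn(x), atol=1e-5)

# Eval mode: nn.BatchNorm2d uses running_mean / running_var, which MyBN does not track
bn.eval()
eval_match = torch.allclose(bn(x), my_bn(x), atol=1e-5)

print(train_match, eval_match)  # True False
```

A full replacement for nn.BatchNorm2d would also need to keep running averages during training and use them in eval mode.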
2. Layer Normalization
Ba J L, Kiros J R, Hinton G E. Layer normalization[J]. arXiv preprint arXiv:1607.06450, 2016.
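When the batch size is small, BN often performs poorly. It is also hard to apply to NLP tasks: sentences in a batch are usually padded with blank tokens at the end, so normalizing the features at those trailing positions across the batch is largely meaningless. LN was proposed to address this. The idea is simple: each sample is normalized on its own.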
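Written out for a single sample, let D be the number of elements covered by normalized_shape. The mean and variance are then computed over those D elements only, with no dependence on other samples in the batch. This is the standard formulation, and nn.LayerNorm with its default elementwise affine transform follows it. Unlike BN's per-channel γ_c and β_c, here γ and β have one entry per normalized element.

```latex
\begin{aligned}
\mu &= \frac{1}{D}\sum_{i=1}^{D} x_i, \qquad
\sigma^2 = \frac{1}{D}\sum_{i=1}^{D}\left(x_i-\mu\right)^2,\\
y_i &= \gamma_i\,\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta_i, \qquad i = 1,\dots,D
\end{aligned}
```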
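Concise PyTorch implementation:

```python
LN = nn.LayerNorm(normalized_shape=10)
# Shape [1, 2, 10]: LN normalizes each length-10 vector along the last dimension
a = torch.tensor([[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                   [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]]], dtype=torch.float32)
print(LN(a))
```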
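The padding argument for LN can be checked directly. The sketch below uses a hypothetical batch of two 4-token sentences with hidden size 8. The last two tokens of the second sentence are overwritten with zero vectors to stand in for padding; this is an assumption for illustration, since real pipelines pad with a pad-token embedding. LN's output for the first sentence is unchanged, while BN's output changes, because BN shares statistics across the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical batch: 2 sentences, 4 token positions, hidden size 8
x = torch.randn(2, 4, 8)

ln = nn.LayerNorm(normalized_shape=8)
bn = nn.BatchNorm1d(num_features=8)  # expects [N, C, L]

def run_bn(t):
    # BatchNorm1d normalizes each hidden unit over the batch and sequence dims
    return bn(t.transpose(1, 2)).transpose(1, 2)

ln_before, bn_before = ln(x), run_bn(x)

# Simulate padding: zero out the last two positions of sentence 1
x_padded = x.clone()
x_padded[1, 2:] = 0.0

ln_after, bn_after = ln(x_padded), run_bn(x_padded)

# LN statistics are per token, so sentence 0 is unaffected by sentence 1's padding
ln_same = torch.allclose(ln_before[0], ln_after[0])
# BN statistics are shared across the batch, so sentence 0's output shifts
bn_same = torch.allclose(bn_before[0], bn_after[0])
print(ln_same, bn_same)  # True False
```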
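Custom LN:

```python
class MyLN(nn.Module):
    def __init__(self, normalized_shape):
        super(MyLN, self).__init__()
        # nn.LayerNorm accepts an int, a list or a tuple; store it as a tuple
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.scale = nn.Parameter(torch.ones(self.normalized_shape))
        self.shift = nn.Parameter(torch.zeros(self.normalized_shape))
        self.eps = 1e-5

    def forward(self, x):
        # Normalize over the last len(normalized_shape) dimensions
        dim = [-(i + 1) for i in range(len(self.normalized_shape))]
        mean = torch.mean(x, dim=dim, keepdim=True)
        var = torch.mean(x**2, dim=dim, keepdim=True) - mean**2
        x = (x - mean) / (torch.sqrt(var + self.eps))
        # scale/shift have shape normalized_shape and broadcast over leading dims
        x = x * self.scale + self.shift
        return x

MLN = MyLN(10)
print(MLN(a))  # should match LN(a) above up to floating-point error
```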
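To check the custom LN against nn.LayerNorm for several forms of normalized_shape (an int, a list and a tuple, covering one or more trailing dimensions), a minimal sketch follows. It uses a self-contained copy of the custom LN that converts any of these forms to a tuple. The input shape and seed are arbitrary illustration choices.

```python
import torch
import torch.nn as nn

class MyLN(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        # Accept an int or a sequence, mirroring nn.LayerNorm's argument
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.scale = nn.Parameter(torch.ones(self.normalized_shape))
        self.shift = nn.Parameter(torch.zeros(self.normalized_shape))
        self.eps = eps

    def forward(self, x):
        # Reduce over the trailing len(normalized_shape) dimensions
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)  # biased variance
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.scale + self.shift

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)

results = {}
for shape in (5, [5], (4, 5), [3, 4, 5]):
    ref = nn.LayerNorm(shape)(x)
    mine = MyLN(shape)(x)
    results[str(shape)] = torch.allclose(ref, mine, atol=1e-5)

print(results)  # every value should be True
```

Computing the variance as the mean of squared deviations, rather than E[x²] − E[x]², avoids the cancellation error that can make the latter slightly negative for inputs with a large mean.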