标准化(Normalization)

1. Batch Normalization

Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[C]//International conference on machine learning. PMLR, 2015: 448-456.

在这里插入图片描述
具体做法如上图所示,计算除了通道维的均值和方差从而进行标准化,即batch中每个样本对应位置的特征做标准化(以均值为例,每个样本第n个通道的特征图计算均值,得到一个均值特征图,再对高和宽再算均值最终得到n个1x1的均值特征图)。其中 γ \gamma γ β \beta β是可学习的超参数。BN的提出是为了加快神经网络的训练并且解决Internal Covariate Shift的问题。使用了BN之后我们可以采用较大的学习率。

此外,BN也可以提供正则化的作用,从而减少Dropout的使用。

pytorch简洁代码实现:

import torch
import torch.nn as nn

BN = nn.BatchNorm2d(num_features=2)
a = torch.tensor([[[1, 2, 3],[4, 5, 6], [7, 8, 9]],
                  [[0, 1, 2],[3, 4, 5], [6, 7, 20]]], dtype=torch.float32)
a = a.unsqueeze(0)
a = a.repeat(3, 1, 1, 1)  # [3,2,3,3]
print(BN(a))

在这里插入图片描述
自定义BN:
结果是一样的

class MyBN(nn.Module):

    def __init__(self, num_features):
        super(MyBN, self).__init__()
        self.scale = nn.Parameter(torch.ones(size=(num_features,)))
        self.shift = nn.Parameter(torch.zeros(size=(num_features,)))
        self.eps = 1e-5

    def forward(self, x):
        # x[N,C,H,W]
        mean = torch.mean(x,dim=[0,2,3], keepdim=True)
        var = torch.mean(x**2, dim=[0,2,3], keepdim=True) - mean**2
        x = (x - mean) / (torch.sqrt(var + self.eps))
        x = x * self.scale.reshape(-1,1,1) + self.shift.reshape(-1,1,1)
        return x
MBN = MyBN(2)
print(MBN(a))

在这里插入图片描述

2. Layer Normalization

Ba J L, Kiros J R, Hinton G E. Layer normalization[J]. arXiv preprint arXiv:1607.06450, 2016.

batch size过小的情况下,BN的效果往往不那么理想,并且很难应用于NLP的任务,因为NLP的句子末尾通常会有填充的空白token,因此batch中每个样本在末尾相对应的特征做BN完全没有意义。为了解决这一问题,LN被提出,原理很简单,就是每个样本自己做标准化即可。
在这里插入图片描述

pytorch简洁代码实现:

LN = nn.LayerNorm(normalized_shape=10)
a = torch.tensor([[[1,2,3,4,5,6,7,8,9,10],
                  [0,0,1,1,2,2,3,3,4,4]]],dtype=torch.float32)
print(LN(a))

在这里插入图片描述
自定义LN:

class MyLN(nn.Module):

    def __init__(self, normalized_shape):
        super(MyLN, self).__init__()
        self.normalized_shape = normalized_shape
        self.scale = nn.Parameter(torch.ones(normalized_shape))
        self.shift = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = 1e-5

    def forward(self, x):
        if isinstance(self.normalized_shape, list):
            dim = [-(i+1) for i in range(len(self.normalized_shape))]
        else:
            dim = -1
        mean = torch.mean(x, dim=dim, keepdim=True)
        var = torch.mean(x**2, dim=dim, keepdim=True) - mean**2
        x = (x - mean) / (torch.sqrt(var + self.eps))
        x = x * self.scale + self.shift
        return x

猜你喜欢

转载自blog.csdn.net/loki2018/article/details/125221692