A Chat About PyTorch's BN Module

BN[1] is one of the most frequently used building blocks in deep learning, but really learning its details inside and out still takes some digging.

1. A Brief Introduction to BN

1.1 About Internal Covariate Shift

The definition of Internal Covariate Shift (ICS) is given in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift:

Training Deep Neural Networks is complicated by the fact that the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This slows down the training by requiring lower learning rates and careful parameter initialization, and makes it notoriously hard to train models with saturating nonlinearities.
We refer to this phenomenon as internal covariate shift, and address the problem by normalizing layer inputs.

Roughly: during training, the distribution of each layer's inputs shifts as the parameters of the preceding layers change, which makes training deep neural networks complicated. Lower learning rates and careful parameter initialization are required, so training is slow, and saturating nonlinearities make models even harder to train. This phenomenon is called Internal Covariate Shift (ICS), and it can be addressed by normalizing the inputs of each layer.

1.2 The BN Formula

The BN formula is:
$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
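
In training mode, E[x] and Var[x] are the per-channel statistics of the current batch. As a quick sanity check, here is a minimal sketch (variable names are my own, not from the original post) that applies the formula by hand and compares it with nn.BatchNorm2d:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)   # freshly constructed: gamma = 1, beta = 0
bn.train()
x = torch.randn(2, 3, 5, 5)

# Per-channel batch statistics over the (N, H, W) dimensions; normalization
# uses the biased variance, hence unbiased=False.
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

y_manual = (x - mean) / torch.sqrt(var + bn.eps)   # gamma = 1, beta = 0, so no scale/shift
print(torch.allclose(y_manual, bn(x), atol=1e-5))  # True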

2. Computing the BN Transform

How do the BN-related parameters stored in the weights correspond to the terms in the formula?

In eval mode, how exactly is a tensor transformed by BN?

These two questions are what we care about most directly when using PyTorch and deploying models.

The example below gives clear answers to both.

import torch
import torch.nn as nn
m = nn.BatchNorm2d(3)

input = torch.randn(2, 3, 5, 5)
m.train()
y = m(input)
print(m.state_dict())
#OrderedDict([('weight', tensor([1., 1., 1.])), 
#('bias', tensor([0., 0., 0.])), 
#('running_mean', tensor([-0.0109, -0.0089,  0.0085])), 
#('running_var', tensor([1.0263, 1.0051, 0.9726])), 
#('num_batches_tracked', tensor(1))])

x = torch.ones(1, 3, 1, 1)
m.eval()
with torch.no_grad():
    print(m(x))  
    
#tensor([[[[0.9978]],

#         [[1.0064]],

#         [[1.0054]]]])

During training the input is torch.randn(2, 3, 5, 5), while in the eval stage the input is torch.ones(1, 3, 1, 1); this shows that BN is applied per channel.

The entries of m.state_dict() correspond to the terms of the BN formula as follows:

m.state_dict() | BN formula
weight | γ
bias | β
running_mean | μ (i.e. E[x])
running_var | Var[x]

The computation that maps x[0,0,0,0] to the output 0.9978 (here $\gamma$ is 1, $\beta$ is 0, and $\epsilon$ is ignored) is:

$$\frac{1 - (-0.0109)}{\sqrt{1.0263}} \approx 0.9978$$
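
Continuing the session above (reusing m and x), the same computation can be reproduced for all three channels at once from the module's buffers; this is only a verification sketch, and the exact numbers depend on the random input used earlier:

m.eval()
with torch.no_grad():
    y_eval = m(x)
    # Rebuild the eval-mode output from running_mean/running_var and gamma/beta.
    manual = (x - m.running_mean.view(1, 3, 1, 1)) / torch.sqrt(
        m.running_var.view(1, 3, 1, 1) + m.eps)
    manual = manual * m.weight.view(1, 3, 1, 1) + m.bias.view(1, 3, 1, 1)
    print(torch.allclose(y_eval, manual, atol=1e-5))  # True
    print(manual[0, 0, 0, 0])                         # ≈ 0.9978 with the buffers shown above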

3. How the BN Statistics (μ, Var, etc.) Are Computed and Updated

Once again, a concrete example: seeing is believing.

import torch
import torch.nn as nn
m = nn.BatchNorm2d(3)

To see the BN module's initial parameters, print them out (this is run in a Jupyter notebook):

m.state_dict()
#OrderedDict([('weight', tensor([1., 1., 1.])),
#             ('bias', tensor([0., 0., 0.])),
#             ('running_mean', tensor([0., 0., 0.])),
#             ('running_var', tensor([1., 1., 1.])),
#             ('num_batches_tracked', tensor(0))])

We can see:

m.state_dict() | BN formula | initial value (dimensions omitted)
weight | γ | 1
bias | β | 0
running_mean | μ | 0
running_var | Var | 1

After one training pass, check the BN parameter values again:

input = torch.randn(2, 3, 5, 5)
m.train()
y = m(input)
print(m.state_dict())

#OrderedDict([
#('weight', tensor([1., 1., 1.])), 
#('bias', tensor([0., 0., 0.])), 
#('running_mean', tensor([0.0156, 0.0036, 0.0010])), 
#('running_var', tensor([1.0216, 1.0215, 1.0002])), 
#('num_batches_tracked', tensor(1))])

3.1 Updating running_mean (μ)

Take channel 0, i.e. input[:,0,:,:], as an example.

First print its mean:

input[:,0,:,:].mean()
#tensor(0.1556)

The corresponding value in running_mean is 0.0156. The difference is because the update takes momentum into account; see reference 3, the official PyTorch documentation:

$$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$$

With momentum at its default value of 0.1, this gives:

$$0.9 \times 0 + 0.1 \times 0.1556 \approx 0.0156$$
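
The same bookkeeping can be checked in code. This is a small sketch with a fresh module and a fresh random input, so the concrete numbers will differ from the ones above:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)          # running_mean starts at 0; momentum defaults to 0.1
x = torch.randn(2, 3, 5, 5)
bn.train()
bn(x)

batch_mean = x[:, 0, :, :].mean()                    # mean of channel 0 over (N, H, W)
expected = (1 - bn.momentum) * 0.0 + bn.momentum * batch_mean
print(expected, bn.running_mean[0])                  # the two values agree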

3.2 Updating running_var (Var)

Similarly, first print its variance:

input[:,0,:,:].var()
#tensor(1.2162)

The corresponding value in running_var is 1.0216, again because of the momentum update:

$$0.9 \times 1 + 0.1 \times 1.2162 \approx 1.0216$$

Training once more with the same input, input[:,0,:,:].var() is of course still tensor(1.2162),

so the updated running_var becomes:

$$0.9 \times 1.0216 + 0.1 \times 1.2162 \approx 1.0411$$
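
These two updates can be replayed in code as well (again a sketch with a fresh module and random input, so the numbers differ; note that .var() returns the unbiased variance by default, which is exactly the value the running_var update uses):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)               # running_var starts at 1
x = torch.randn(2, 3, 5, 5)
batch_var = x[:, 0, :, :].var()      # unbiased variance of channel 0

bn.train()
bn(x)
after_first = 0.9 * 1.0 + 0.1 * batch_var
print(after_first, bn.running_var[0])    # agree after the first pass

bn(x)                                # second pass with the same input
after_second = 0.9 * after_first + 0.1 * batch_var
print(after_second, bn.running_var[0])   # agree after the second pass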

3.3 Updating the Other Parameters (γ, β)

If affine=True, the parameters γ and β are updated through backpropagation of the loss.
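
A minimal sketch of this (the loss below is an arbitrary toy target of my own choosing, purely to show that weight (γ) and bias (β) receive gradients and are moved by the optimizer step):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                      # affine=True by default, so gamma/beta are nn.Parameters
opt = torch.optim.SGD(bn.parameters(), lr=0.1)

bn.train()
x = torch.randn(2, 3, 5, 5)
loss = (bn(x) - 1.0).pow(2).mean()          # arbitrary toy loss
loss.backward()
print(bn.weight.grad, bn.bias.grad)         # both gradients are populated

opt.step()
print(bn.weight, bn.bias)                   # gamma/beta have moved away from 1 and 0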

4. An Implementation from the Experts

Mu Li's team, in reference 4 (d2l.ai, Batch Normalization: Implementation from Scratch) and reference 6, provide BN implementations for the three mainstream frameworks, as shown in the figure; they are genuinely helpful for understanding how a BN layer is implemented.

[Figure: batch normalization implementations for the three mainstream frameworks, from d2l.ai]

The PyTorch version is pasted below:

import torch
from torch import nn
from d2l import torch as d2l

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use `is_grad_enabled` to determine whether the current mode is training
    # mode or prediction mode
    if not torch.is_grad_enabled():
        # If it is prediction mode, directly use the mean and variance
        # obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # When using a fully-connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean)**2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of `X`, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean)**2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used for the
        # standardization
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # `num_features`: the number of outputs for a fully-connected layer
    # or the number of output channels for a convolutional layer. `num_dims`:
    # 2 for a fully-connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and 1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # If `X` is not on the main memory, copy `moving_mean` and
        # `moving_var` to the device where `X` is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated `moving_mean` and `moving_var`
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
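
To try the class out, here is a short usage sketch continuing the code above. Two things to keep in mind: in this implementation momentum weights the old moving statistics, so momentum=0.9 here plays the same role as PyTorch's default momentum=0.1, and training vs. prediction is decided by torch.is_grad_enabled() rather than by .train()/.eval():

X = torch.randn(2, 3, 5, 5)
bn = BatchNorm(3, num_dims=4)

# With gradients enabled, the batch statistics are used and the moving
# statistics are updated.
Y = bn(X)
print(Y.shape)                    # torch.Size([2, 3, 5, 5])
print(bn.moving_mean.squeeze())   # no longer all zeros after one update

# With gradients disabled, the moving statistics are used instead.
with torch.no_grad():
    print(bn(torch.ones(1, 3, 1, 1)).squeeze())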

References

1. Normalization in deep learning explained in detail: not just BN
2. PyTorch Study Notes 19: Batch Normalization
3. Official PyTorch BatchNorm2d documentation
4. http://d2l.ai/chapter_convolutional-modern/batch-norm.html
5. Zhihu discussion: why does Batch Normalization work so well in deep learning?
6. d2l-pytorch-colab batch-norm


1. BN: batch normalization ↩︎
