pytorch之——Batch Normalization

pytorch之——Batch Normalization


一、Batch Normalization概念

1.Batch Normalization:批标准化

批:一批数据,通常为mini_batch
标准化:0均值,1方差
优点:
(1)可以用更大的学习率,加速模型收敛
(2)可以不用精心设计权值初始化(因为权值初始化也是为了缩放数据尺度,BN有着同样的作用)
(3)可以不用dropout或较小的dropout(论文中实验尝试得到)
(4)可以不用L2或者较小的weight decay(论文中实验尝试得到)
(5)可以不用LRN(local response normalization)(也是一种归一化)

2.基本动机与原理

  神经网络训练的本质是学习数据分布,如果训练数据和与测试数据的分布不同讲大大降低网路的泛化能力,因此我们需要在训练开始前对所有的数据进行归一化处理。
  然而随着网络训练的进行,每个隐藏层的参数变化使得后一层的输入发生变化,从而每一批训练数据的分布也随之改变,致使神经网络在每次迭代中都需要拟合不同的数据分布,增大训练的复杂度以及过拟合的风险。
  批标准化方法时针对每一批数据,在网络的每一层输入之前增加归一化处理(均值为0,标准差为1),将所有的批数据强制在统一的数据分布下。

3.计算方式

在这里插入图片描述
可学习参数gamma和beta的作用:增加模型的可学习能力。是为了使模型自己学习该数据是否需要标准化

二、Pytorch的Batch Normalization 1d/2d/3d实现

1._BatchNorm(基类)

在这里插入图片描述

2.Batch Normalization 1d/2d/3d

在这里插入图片描述
注意事项:pytorch在实现BN的时候,当前时刻的均值和方差也考虑了之前时刻的均值和方差,具体计算方式如上图中的running_mean和running_var。

pre_running_mean为之前时刻的均值。
mean_t为当前t时刻的均值和方差

3.Batch Normalization 1d/2d/3d的输入及计算方式

在这里插入图片描述
说明:BN在计算时,是在每一批数据的每一个特征维度上分别计算一个均值和方差,如上图的讲解

4.pytorch代码实现

bn = orch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #1d

在这里插入图片描述

bn = torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#2d

在这里插入图片描述

bn = torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) #3d

在这里插入图片描述

参考

深度之眼pytorch框架班以及pytorch中文文档

猜你喜欢

转载自blog.csdn.net/weixin_43183872/article/details/108293899