理解 Batch Normalization

前几天在面试过程中问到了有关BN层的相关知识,自己虽然对于基本的原理知道,但是更为深层的东西有些模糊,所以结束之后再补充学习一下~

BN

机器学习中一个很重要的假设是独立同分布假设,即数据之间是彼此独立的,而且训练数据和测试数据之间满足同分布的要求。之所以做这样的假设,一方面希望模型可以从不同的特征中进行学习,而是过度依赖某几个特征;另一方面是为了使得在训练集上表现良好的模型具备足够好的泛化性能。如果训练集数据满足的分布和测试集数据满足的分布不同,那么训练好的模型自然不可能在测试集上表现良好。

那么为什么要对输入数据做规范化操作呢?我们可以从传统的机器学习中的规范化操作进行理解,ML中通常都需要对原始数据进行中心化(Zero-Centered)和标准化(Normalization)处理,目的是得到均值为0,标准差为1 的服从标准正态分布的数据。这样做的目的是,通过标准化处理使得不同量纲的特征具有相同的尺度,防止不同量纲的特征对于模型参数训练的影响,同时可以加快模型的收敛。


在这里插入图片描述

同理,深度学习的本质同样是为了学习数据所满足的分布,如果测试数据和训练数据分布不同,那么模型的泛化性能自然很差;另外如果模型训练时每个mini-batch的数据满足的分布都不同,那么模型在每一次迭代时都需要调整参数来适应不同的分布,这将大大的降低模型的收敛速度。

在理解BN为什么有用前,我们需要明白一个问题。Covariate Shift可以译为协变量偏移,它主要用于描述训练集和测试集在分布上具有差异性,从而影响了模型的泛化性和训练速度,通常可以用归一化操作缓解。对于深度学习这种包含很多隐层的网络结构,在训练过程中,因为各层参数不停在变化,所以每个隐层都会面临covariate shift的问题,也就是**在训练过程中,隐层的输入分布老是变来变去,这就是所谓的“Internal Covariate Shift”。

那么能不能让每个隐层节点的激活输入分布固定下来呢?这样就避免了“Internal Covariate Shift”问题了,顺带解决反向传播中梯度消失问题。BN 其实就是在做 feature scaling,而且它的目的也是为了在训练的时候避免这种 Internal Covariate Shift 的问题,只是刚好也解决了 sigmoid 函数梯度消失的问题。

BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的激活输入值(就是那个 x = W U + B x=WU+B U U 是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值 W U + B WU+B 是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因

而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。

对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,也就是说收敛地快。BN说到底就是这么个机制,方法很简单,道理很深刻。

如果只是单纯的每次都将激活函数的输入值拉回到非线性函数的线性区间内,那么多层神经网络也只是相当于线性叠加,这样做将大大的削弱模型的非线性拟合能力,即模型自身的学习能力。那么为了平衡这两个问题同时在一定程度上保留模型已经学到的东西,在BN的最后一步还添加了平移和放缩参数 γ \gamma β \beta 。平移和放缩参数的目的是使规范化后的分布不只是完全满足均值为零、方差为1的标准正态分布,而是允许它左偏一点或是右偏一点,从而保持模型的非线性表达能力。

BN层算法描述:


在这里插入图片描述

在这里插入图片描述

其中 γ \gamma β \beta 是平移和缩放因子。它的主要做法为:

  • 直接对输入的每个维度做规范化
  • 在每个mini-batch中计算得到的均值和方差来代替整体训练集的均值和方差。

BatchNorm的推理(Inference)过程

BN在训练的时候可以根据Mini-Batch里的若干训练实例进行激活数值调整,但是在推理(inference)的过程中,很明显输入就只有一个实例,看不到Mini-Batch其它实例,那么这时候怎么对输入做BN呢?因为很明显一个实例是没法算实例集合求出的均值和方差的。这可如何是好?既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可。

决定了获得统计量的数据范围,那么接下来的问题是如何获得均值和方差的问题。很简单,因为每次做Mini-Batch训练时,都会有那个Mini-Batch里m个训练实例获得的均值和方差,现在要全局统计量,只要把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量

优点:

  • 加快网络的训练和收敛的速度:将每一层的数据都转换为均值为零、方差为1的高斯分布,每层的分布相同则易收敛
  • 防止梯度弥散:使用BN层归一化后,通过将激活值规范化为均值和方差一致的手段使得原本会减小的activation的尺度变大
  • 防止过拟合:在网络的训练中,BN的使用使得一个mini batch中所有样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,即同样一个样本的输出不再仅仅取决于样本的本身,也取决于跟这个样本同属一个batch的其他样本,而每次网络都是随机取batch,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合
  • 使得调参过程更简单

为什么BN层一般用在线性层和卷积层后面,而不是放在非线性单元后?

因为非线性单元的输出分布形状会在训练过程中变化,归一化无法消除它的方差偏移,相反的,全连接和卷积层的输出一般是一个对称、非稀疏的一个分布,更加类似高斯分布,对它们进行归一化会产生更加稳定的分布。


batchnorm原理及代码详解

深度学习中 Batch Normalization为什么效果好?

残差网络解决了什么,为什么有效?

深度学习—BN的理解(一)

深入理解Batch Normalization批标准化
为什么Batch Normalization那么有用?

发布了448 篇原创文章 · 获赞 122 · 访问量 22万+

猜你喜欢

转载自blog.csdn.net/Forlogen/article/details/105358096