Batch Normalization: Mathematical Principles and Derivations


As data passes through layer after layer of the network, the distribution of the outputs shifts. This phenomenon is called Internal Covariate Shift, and it makes learning harder for the downstream layers:

  1. The network as a whole learns more slowly;
  2. The data distribution easily falls into the gradient-saturation regions of activation functions such as sigmoid and tanh, slowing the network's convergence.
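
To see how severe the saturation is, here is a quick numerical check (a minimal NumPy sketch of my own; the sample points are illustrative, not from the original post) of how fast the sigmoid gradient collapses away from zero:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    # Derivative of the sigmoid: s * (1 - s).
    s = sigmoid(x)
    return s * (1.0 - s)

# The gradient peaks at x = 0 and collapses in the saturation regions.
for x in [0.0, 2.0, 5.0, 10.0]:
    print(f"x = {x:5.1f}  sigmoid'(x) = {sigmoid_grad(x):.6f}")
# x =   0.0  sigmoid'(x) = 0.250000
# x =   2.0  sigmoid'(x) = 0.104994
# x =   5.0  sigmoid'(x) = 0.006648
# x =  10.0  sigmoid'(x) = 0.000045
```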

Naively normalizing every layer is not reasonable

If every layer's output were normalized to the standard normal distribution, with mean 0 and variance 1, the network would be unable to learn the characteristics of the input data at all, because all of the features would be normalized away.

The information captured by the parameters of the lower layers would be thrown away, weakening the network's ability to represent the data. This is why the scale-and-shift parameters $\gamma$ and $\beta$ in step 4 below are needed: they let the network recover its expressiveness, up to and including the identity transform.

Steps of Batch Normalization [1]

  1. Compute the mean of the whole batch:
    $\mu_b=\frac{1}{m}\sum_{i=1}^m x_i$

  2. Compute the variance of the whole batch:
    $\sigma^2_b = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_b)^2$

  3. Normalize the data toward a standard normal distribution:
    $\hat{x}_i = \frac{x_i-\mu_b}{\sqrt{\sigma^2_b+\varepsilon}}$
    where $\varepsilon$ guards against a variance of 0.

  4. Introduce the scale-and-shift parameters to obtain the final normalized result (a NumPy sketch of all four steps follows this list):
    $y_i = \gamma\hat{x}_i+\beta$
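
Putting the four steps together, a training-time forward pass might look like the following minimal NumPy sketch (the function name and the per-feature axis handling are my own choices, not from the paper):

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-time batch normalization for a (batch, features) array."""
    mu_b = x.mean(axis=0)                      # step 1: batch mean per feature
    var_b = x.var(axis=0)                      # step 2: (biased) batch variance
    x_hat = (x - mu_b) / np.sqrt(var_b + eps)  # step 3: normalize
    return gamma * x_hat + beta                # step 4: scale and shift

# Example: a batch of 4 samples with 3 features, shifted and scaled.
x = np.random.randn(4, 3) * 5.0 + 2.0
y = batch_norm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
print(y.mean(axis=0))  # ~0 for every feature
print(y.std(axis=0))   # ~1 for every feature
```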

At test time, the population variance is estimated with the unbiased variance estimate [2]

Question: Is the estimate of the population variance that arises in this way using the sample mean always smaller than what we would get if we used the population mean?
Answer: Yes, except when the sample mean happens to be the same as the population mean.

We are seeking the sum of squares of distances from the population mean, but end up calculating the sum of squares of differences from the sample mean, which, as will be seen, is the number that minimizes that sum of squares of distances. So unless the sample happens to have the same mean as the population, this estimate will always underestimate the sum of squared differences from the population mean.

The proof goes as follows:

Let the sample mean be $\mu_s$ and the population mean be $\mu$. Then
$$\mu_s = \frac{1}{m}\sum_{i=1}^m x_i \tag{1}$$

Let $a_i = x_i - \mu_s$ and $b = \mu_s - \mu$. Then
$$(x_i - \mu)^2 = \left[(x_i - \mu_s) + (\mu_s - \mu)\right]^2 = (a_i+b)^2 = a^2_i+2a_ib+b^2 \tag{2}$$
Applying Eq. (2) to every element of the sample and summing gives:
$$\begin{aligned} \sum_{i=1}^m(x_i - \mu)^2 &= \sum_{i=1}^m a_i^2 + mb^2 + 2\sum_{i=1}^m a_ib\\ &= \sum_{i=1}^m a_i^2 + mb^2 + 2b\sum_{i=1}^m a_i\\ &= \sum_{i=1}^m a_i^2 + mb^2\\ &= \sum_{i=1}^m(x_i-\mu_s)^2 + mb^2\\ &\ge \sum_{i=1}^m(x_i-\mu_s)^2 \end{aligned}\tag{3}$$
The simplification in Eq. (3) uses the fact that deviations from the sample mean sum to zero:
$$\sum_{i=1}^m a_i = \sum_{i=1}^m(x_i-\mu_s) = \sum_{i=1}^m x_i - m\mu_s = 0$$

Eq. (3) shows that the sum of squared deviations from the population mean is always greater than (or equal to) the sum of squared deviations from the sample mean. In other words, the variance estimate built around the sample mean underestimates the population variance, and the two agree only when the sample mean happens to equal the population mean, in which case the estimate is exact.
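
A quick Monte Carlo check of this conclusion (a sketch with arbitrary parameters of my own choosing, not from the original post): averaging the biased estimator $\frac{1}{m}\sum_{i=1}^m(x_i-\mu_s)^2$ over many samples lands below the true variance by roughly the factor $(m-1)/m$:

```python
import numpy as np

rng = np.random.default_rng(0)
m, true_var, trials = 5, 4.0, 200_000

# Draw many samples of size m from a population with variance 4.0.
samples = rng.normal(0.0, np.sqrt(true_var), size=(trials, m))
mu_s = samples.mean(axis=1, keepdims=True)

biased = ((samples - mu_s) ** 2).mean(axis=1)             # divides by m
unbiased = ((samples - mu_s) ** 2).sum(axis=1) / (m - 1)  # divides by m - 1

print(biased.mean())    # ~3.2 = (m-1)/m * 4.0 -> systematically too small
print(unbiased.mean())  # ~4.0 -> Bessel's correction removes the bias
```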

Correcting the population-variance estimate with Bessel's correction

The sample mean is given by Eq. (1).

The (biased) sample variance, which expands via $\sum_{i=1}^m(x_i-\mu_s)^2 = \sum_{i=1}^m x_i^2 - m\mu_s^2$:
$$\sigma_s^2 = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_s)^2 = \frac{\sum_{i=1}^m x_i^2}{m} - \frac{\left(\sum_{i=1}^m x_i\right)^2}{m^2}$$

Bessel's correction replaces the $m$ in the denominator with $m-1$, giving the corrected estimate of the population variance:
$$\sigma^2 = \frac{1}{m-1}\sum_{i=1}^m(x_i-\mu_s)^2 = \frac{\sum_{i=1}^m x_i^2}{m-1} - \frac{\left(\sum_{i=1}^m x_i\right)^2}{m(m-1)} = \frac{m}{m-1}\sigma^2_s$$
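
This is the correction Batch Normalization applies when accumulating statistics for test time: during training the layer normalizes with the biased batch variance, but the running variance used at inference is commonly updated with the Bessel-corrected estimate $\frac{m}{m-1}\sigma_b^2$. A minimal sketch (the momentum value and the exponential-moving-average update rule mirror common framework implementations rather than any specific API):

```python
import numpy as np

def bn_update_running_stats(x, running_mean, running_var, momentum=0.1):
    """Update the running statistics from one training batch (sketch only)."""
    m = x.shape[0]
    mu_b = x.mean(axis=0)
    var_b = x.var(axis=0)               # biased: divides by m
    var_unbiased = var_b * m / (m - 1)  # Bessel's correction: m/(m-1) * var_b
    running_mean = (1 - momentum) * running_mean + momentum * mu_b
    running_var = (1 - momentum) * running_var + momentum * var_unbiased
    return running_mean, running_var

def bn_inference(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Test-time forward pass: normalize with the accumulated estimates."""
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta
```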

Further reading:

Why is the denominator of the sample variance n-1? - 张英锋's answer on Zhihu
https://www.zhihu.com/question/20099757/answer/658048814

Advantages of Batch Normalization

  1. BN keeps the distribution of each layer's input data relatively stable, which speeds up learning;
    to some extent, BN pulls data sitting near the gradient-saturation regions back toward the central, roughly linear region, enlarging the gradients.

  2. BN simplifies hyperparameter tuning and makes the network more stable;
    it suppresses the tendency of small parameter perturbations to be amplified as the network gets deeper, making the network more tolerant of parameter settings.

  3. BN lets the network use activation functions such as sigmoid and tanh while mitigating the vanishing-gradient problem;

  4. BN has a mild regularizing effect, since the statistics of each mini-batch differ slightly and thus inject noise into the activations.

References


  1. Ioffe, S. and Szegedy, C., 2015. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. arXiv preprint arXiv:1502.03167. ↩︎

  2. Bessel's correction: Proof of correctness (Alternate 3). Wikipedia. https://en.wikipedia.org/wiki/Bessel's_correction#Proof_of_correctness_-_Alternate_3 ↩︎
