批量标准化(batch normalization简称BN)主要是为了克服当神经网络层数加深而导致难以训练而诞生的。当深度神经网络随着网络深度加深,训练起来会越来越困难,收敛速度会很慢,还会产生梯度消失问题(vanishing gradient problem)。
在统计机器学习领域中有一个ICS(Internal Covariate Shift)理论:源域(source domain)和目标域(target domain)的数据分布是一致的。也就是训练数据和测试数据满足相同的分布,这是通过训练数据获得的模型在测试数据上有一个好的效果的保证。
Covariate Shift是指训练数据的样本和测试数据的样本分布不一致时,训练获取的模型无法很好的泛化。它是分布不一致假设之下的一个分支问题,也就是指源域和目标域的条件概率是一致的,但是其边缘概率不同。对于神经网络而言,神经网络的各层输出,在经过了层内操作后,各层输出分布会随着输入分布的变化而变化,而且差异会随着网络的深度增加而加大,但是每一层随指向的样本标记是不会改变的。
解决Covariate Shift问题可以通过对训练样本和测试样本的比例对训练样本做一个矫正,通过批量标准化来标准化某些层或所有层的输入,从而固定每层输入信号的均值与方差。
一、批量标准化的实现
批量标准化是在激活函数之前,对z=wx+b做标准化,使得输出结果满足标准的正态分布,即均值为0,方差为1。让每一层的输入有一个稳定的分布便于网络的训练。
二、批量标准化的优点
1、加大探索的步长,加快模型收敛的速度
2、更容易跳出局部最小值
3、破坏原来的数据分布,在一定程度上可以缓解过拟合。
当遇到神经网络收敛速度很慢或梯度爆炸等无法训练的情况时,可以尝试使用批量标准化来解决问题。
三、TensorFlow的批量标准化实例
1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)
函数介绍:计算x的均值和方差
参数介绍:
- x:需要计算均值和方差的tensor
- axes:指定求解x某个维度上的均值和方差,如果x是一维tensor,则axes=[0]
- name:用于计算均值和方差操作的名称
- keep_dims:是否产生与输入相同相同维度的结果
z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
#计算z的均值和方差
#计算列的均值和方差
z_mean_col,z_var_col = tf.nn.moments(z,axes=[0])
#[1.5 1.5 1.5 1.5 1.5] [0.25 0.25 0.25 0.25 0.25]
#计算行的均值和方差
z_mean_row,z_var_row = tf.nn.moments(z,axes=[1])
#等价于axes=[-1],-1表示最后一维
#[1. 2.] [0. 0.]
#计算整个数组的均值和方差
z_mean,z_var = tf.nn.moments(z,axes=[0,1])
#1.5 0.25
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(z))
print(sess.run(z_mean_col),sess.run(z_var_col))
print(sess.run(z_mean_row),sess.run(z_var_row))
print(sess.run(z_mean),sess.run(z_var))
2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)
函数介绍:计算batch normalization
参数介绍:
- x:输入的tensor,具有任意的维度
- mean:输入tensor的均值
- variance:输入tensor的方差
- offset:偏置tensor,初始化为1
- scale:比例tensor,初始化为0
- variance_epsilon:一个接近于0的数,避免除以0
z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
#计算z的均值和方差
z_mean,z_var = tf.nn.moments(z,axes=[0,1])
scale = tf.Variable(tf.ones([2,5]))
shift = tf.Variable(tf.zeros([2,5]))
#计算batch normalization
z_bath_norm = tf.nn.batch_normalization(z,z_mean,z_var,shift,scale,variance_epsilon=0.001)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(z))
print(sess.run(z_bath_norm))