TensorFlow的batch_normalization

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_29957455/article/details/81806271

批量标准化(batch normalization简称BN)主要是为了克服当神经网络层数加深而导致难以训练而诞生的。当深度神经网络随着网络深度加深,训练起来会越来越困难收敛速度会很慢,还会产生梯度消失问题(vanishing gradient problem)。

在统计机器学习领域中有一个ICS(Internal Covariate Shift)理论:源域(source domain)和目标域(target domain)的数据分布是一致的。也就是训练数据和测试数据满足相同的分布,这是通过训练数据获得的模型在测试数据上有一个好的效果的保证。

Covariate Shift是指训练数据的样本和测试数据的样本分布不一致时,训练获取的模型无法很好的泛化。它是分布不一致假设之下的一个分支问题,也就是指源域和目标域的条件概率是一致的,但是其边缘概率不同。对于神经网络而言,神经网络的各层输出,在经过了层内操作后,各层输出分布会随着输入分布的变化而变化,而且差异会随着网络的深度增加而加大,但是每一层随指向的样本标记是不会改变的。

解决Covariate Shift问题可以通过对训练样本和测试样本的比例对训练样本做一个矫正,通过批量标准化来标准化某些层或所有层的输入,从而固定每层输入信号的均值与方差。

一、批量标准化的实现

批量标准化是在激活函数之前,对z=wx+b做标准化,使得输出结果满足标准的正态分布,即均值为0,方差为1。让每一层的输入有一个稳定的分布便于网络的训练。

二、批量标准化的优点

1、加大探索的步长,加快模型收敛的速度

2、更容易跳出局部最小值

3、破坏原来的数据分布,在一定程度上可以缓解过拟合。

当遇到神经网络收敛速度很慢或梯度爆炸等无法训练的情况时,可以尝试使用批量标准化来解决问题。

三、TensorFlow的批量标准化实例

1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)

扫描二维码关注公众号,回复: 3992160 查看本文章

函数介绍:计算x的均值和方差

参数介绍:

  • x:需要计算均值和方差的tensor
  • axes:指定求解x某个维度上的均值和方差,如果x是一维tensor,则axes=[0]
  • name:用于计算均值和方差操作的名称
  • keep_dims:是否产生与输入相同相同维度的结果
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
    #计算z的均值和方差
    #计算列的均值和方差
    z_mean_col,z_var_col = tf.nn.moments(z,axes=[0])
    #[1.5 1.5 1.5 1.5 1.5] [0.25 0.25 0.25 0.25 0.25]
    #计算行的均值和方差
    z_mean_row,z_var_row = tf.nn.moments(z,axes=[1])
    #等价于axes=[-1],-1表示最后一维
    #[1. 2.] [0. 0.]
    #计算整个数组的均值和方差
    z_mean,z_var = tf.nn.moments(z,axes=[0,1])
    #1.5 0.25
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(z))
    print(sess.run(z_mean_col),sess.run(z_var_col))
    print(sess.run(z_mean_row),sess.run(z_var_row))
    print(sess.run(z_mean),sess.run(z_var))

2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)

函数介绍:计算batch normalization

参数介绍:

  • x:输入的tensor,具有任意的维度
  • mean:输入tensor的均值
  • variance:输入tensor的方差
  • offset:偏置tensor,初始化为1
  • scale:比例tensor,初始化为0
  • variance_epsilon:一个接近于0的数,避免除以0
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
    #计算z的均值和方差
    z_mean,z_var = tf.nn.moments(z,axes=[0,1])
    scale = tf.Variable(tf.ones([2,5]))
    shift = tf.Variable(tf.zeros([2,5]))
    #计算batch normalization
    z_bath_norm = tf.nn.batch_normalization(z,z_mean,z_var,shift,scale,variance_epsilon=0.001)
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(z))
    print(sess.run(z_bath_norm))

猜你喜欢

转载自blog.csdn.net/sinat_29957455/article/details/81806271
今日推荐