BN Implementation

This should be the BN code that comes closest to my understanding of the paper. If there are any problems, corrections are welcome.
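For reference, the transform the code below implements is the one from the BN paper: y = gamma * (x - mean) / sqrt(variance + eps) + beta, where mean and variance are the per-channel batch statistics during training and the accumulated moving averages at inference, and gamma/beta are the learned affine parameters (applied only when affine=True).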


import tensorflow as tf
from tensorflow.python.training.moving_averages import assign_moving_average


def batch_norm(x, n_out, train, eps=1e-05, decay=0.99, affine=True, name=None):
    with tf.variable_scope(name, default_name='BatchNorm2d'):
        # Non-trainable variables that accumulate the moving averages
        # used at inference time.
        moving_mean = tf.get_variable('mean', [n_out],
                                      initializer=tf.zeros_initializer(),
                                      trainable=False)
        moving_variance = tf.get_variable('variance', [n_out],
                                          initializer=tf.ones_initializer(),
                                          trainable=False)

        train = tf.convert_to_tensor(train)

        def mean_var_with_update():
            # Batch statistics over N, H, W (assumes NHWC input).
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
            # Update the moving averages during training so they are
            # available for inference.
            with tf.control_dependencies(
                    [assign_moving_average(moving_mean, mean, decay),
                     assign_moving_average(moving_variance, variance, decay)]):
                return tf.identity(mean), tf.identity(variance)

        # train=True: return the batch mean/variance (updating the moving
        # averages as a side effect); train=False: return moving_mean and
        # moving_variance, which were already updated during training.
        mean, variance = tf.cond(train, mean_var_with_update,
                                 lambda: (moving_mean, moving_variance))

        if affine:
            # Learnable shift (beta) and scale (gamma).
            beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
                               name='beta', trainable=True)
            gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
                                name='gamma', trainable=True)
            x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)
        else:
            x = tf.nn.batch_normalization(x, mean, variance, None, None, eps)
        return x
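
A minimal usage sketch, assuming TensorFlow 1.x and an NHWC feature map; the placeholder names, shapes, and the scope name 'bn1' are illustrative assumptions, not part of the original post:

import numpy as np
import tensorflow as tf

# Hypothetical 4-D input (batch, height, width, channels) and a boolean
# placeholder that switches between training and inference behavior.
x = tf.placeholder(tf.float32, [None, 32, 32, 64], name='feature_map')
is_training = tf.placeholder(tf.bool, name='is_training')

y = batch_norm(x, n_out=64, train=is_training, name='bn1')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(8, 32, 32, 64).astype(np.float32)
    # train=True: normalize with batch statistics and update the moving averages.
    sess.run(y, feed_dict={x: batch, is_training: True})
    # train=False: normalize with the accumulated moving_mean / moving_variance.
    sess.run(y, feed_dict={x: batch, is_training: False})

Because the moving-average updates are wired in through tf.control_dependencies inside the cond branch, no separate update ops need to be run during training.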

Reposted from blog.csdn.net/m0_37561765/article/details/79734990