This should be the batch-norm implementation closest to my reading of the paper; corrections are welcome if anything is wrong.
def batch_norm(x, n_out, train, eps=1e-05, decay=0.99, affine=True, name=None):
    """Batch normalization over the [batch, height, width] axes of a 4-D input.

    When ``train`` is truthy the batch statistics are used and the stored
    moving averages are updated as a side effect; otherwise the moving
    averages accumulated during training are used (inference mode).

    Args:
        x: 4-D input tensor; moments are taken over axes [0, 1, 2], so the
            layout is assumed to be NHWC — TODO confirm against callers.
        n_out: number of channels (size of the last axis of ``x``).
        train: Python bool or scalar bool tensor selecting train/inference.
        eps: small constant added to the variance for numerical stability.
        decay: exponential-moving-average decay for the stored statistics.
        affine: when True, apply a learned scale (gamma) and offset (beta).
        name: optional variable-scope name (defaults to 'BatchNorm2d').

    Returns:
        The normalized tensor, same shape as ``x``.
    """
    with tf.variable_scope(name, default_name='BatchNorm2d'):
        moving_mean = tf.get_variable('mean', [n_out],
                                      initializer=tf.zeros_initializer,
                                      trainable=False)
        moving_variance = tf.get_variable('variance', [n_out],
                                          initializer=tf.ones_initializer,
                                          trainable=False)
        train = tf.convert_to_tensor(train)

        def mean_var_with_update():
            # Batch statistics for the current mini-batch; also fold them
            # into the moving averages so inference can use them later.
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
            from tensorflow.python.training.moving_averages import assign_moving_average
            with tf.control_dependencies([assign_moving_average(moving_mean, mean, decay),
                                          assign_moving_average(moving_variance, variance, decay)]):
                return tf.identity(mean), tf.identity(variance)

        # train=True: use (and update) the batch mean/variance.
        # train=False: use the moving averages updated during training.
        mean, variance = tf.cond(train, mean_var_with_update,
                                 lambda: (moving_mean, moving_variance))

        if affine:
            # Use get_variable (not tf.Variable) so beta/gamma live in this
            # variable scope and honor scope reuse, consistent with the
            # moving_mean/moving_variance variables above. Initializers are
            # equivalent to the original constant 0.0/1.0 fills.
            beta = tf.get_variable('beta', [n_out],
                                   initializer=tf.zeros_initializer,
                                   trainable=True)
            gamma = tf.get_variable('gamma', [n_out],
                                    initializer=tf.ones_initializer,
                                    trainable=True)
            x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)
        else:
            x = tf.nn.batch_normalization(x, mean, variance, None, None, eps)
        return x