# Batch Renormalization (batch renorm) implementation

class batch_renorm(object):
    """Batch Renormalization layer (Ioffe, 2017) for TensorFlow 1.x graphs.

    In training mode the input is normalized with the mini-batch statistics
    and then corrected by ``r = sigma_B / sigma`` and
    ``d = (mu_B - mu) / sigma``. Here ``mu``/``sigma`` are the moving
    statistics. ``r`` is clipped to ``[1/RMAX, RMAX]`` and ``d`` to
    ``[-DMAX, DMAX]``, and both are excluded from back-propagation.
    In inference mode the moving statistics are used directly.

    With the defaults ``RMAX=1`` and ``DMAX=0`` the clipping forces
    ``r == 1`` and ``d == 0``, i.e. plain batch normalization.
    """

    def __init__(self, n_out, renorm_momentum=0.97, RMAX=1, DMAX=0, epsilon=1e-3):
        """Create the (non-trainable) moving statistics.

        Args:
            n_out: length of the per-channel parameter/statistic vectors.
                Presumably the size of the input's last axis -- confirm
                against callers.
            renorm_momentum: decay passed to ``assign_moving_average`` when
                updating the moving mean/variance.
            RMAX: upper clip bound for ``r``; the lower bound is ``1/RMAX``.
            DMAX: clip bound for ``|d|``.
            epsilon: value added to variances before taking square roots.
        """
        self.n_out = n_out
        self.moving_mean = _variable_on_cpu('moving_mean', [self.n_out],
                                            initializer=tf.zeros_initializer,
                                            train=False)
        self.moving_variance = _variable_on_cpu('moving_variance', [self.n_out],
                                                initializer=tf.ones_initializer,
                                                train=False)
        self.epsilon = epsilon
        self.RMAX = RMAX
        self.DMAX = DMAX
        self.renorm_momentum = renorm_momentum
        # beta/gamma are created on the first __call__ (so they live in that
        # call's name scope) and reused afterwards, so every call of this
        # layer shares one set of affine parameters, like the moving stats.
        self.beta = None
        self.gamma = None

    def __call__(self, inputs, train=True, axes=None):
        """Apply batch renormalization to ``inputs``.

        Args:
            inputs: tensor (or tensor-convertible value) to normalize.
            train: Python bool or boolean tensor; ``tf.cond`` on it selects
                the training path (batch statistics, renorm correction and
                moving-average update) or the inference path (moving stats).
            axes: axes reduced to compute the batch statistics. Defaults to
                every axis except the last when the static rank is known
                (``[0, 1, 2]`` for 4-D NHWC input); otherwise ``[0, 1, 2]``.

        Returns:
            The normalized tensor produced by ``tf.nn.batch_normalization``.
        """
        inputs = tf.convert_to_tensor(inputs)

        if self.beta is None:
            self.beta = tf.Variable(tf.constant(0.0, shape=[self.n_out]),
                                    name='beta', trainable=True)
            self.gamma = tf.Variable(tf.constant(1.0, shape=[self.n_out]),
                                     name='gamma', trainable=True)
        beta, gamma = self.beta, self.gamma

        if axes is None:
            ndims = inputs.get_shape().ndims
            axes = list(range(ndims - 1)) if ndims is not None and ndims >= 2 else [0, 1, 2]

        def _batch_norm_training():
            from tensorflow.python.ops import math_ops

            batch_mean, batch_variance = tf.nn.moments(inputs, axes, name='moments')

            # Correction factors relative to the moving statistics; they are
            # treated as constants for the gradient (stop_gradient).
            moving_inv = math_ops.rsqrt(self.moving_variance + self.epsilon)
            # 1.0 (not 1) so an integer RMAX > 1 does not truncate to 0 under
            # Python 2 division.
            r = tf.stop_gradient(
                tf.clip_by_value(tf.sqrt(batch_variance + self.epsilon) * moving_inv,
                                 1.0 / self.RMAX, self.RMAX),
                name='renorm_r')
            d = tf.stop_gradient(
                tf.clip_by_value((batch_mean - self.moving_mean) * moving_inv,
                                 -self.DMAX, self.DMAX),
                name='renorm_d')

            # gamma * (x_hat * r + d) + beta, folded into the (scale, offset)
            # arguments of tf.nn.batch_normalization.
            scale = r * gamma
            offset = d * gamma + beta

            # The updates must run after r and d are computed, so the
            # correction uses the pre-update moving statistics. Without this
            # dependency the graph gives no ordering between those reads and
            # the assigns. zero_debias=False: zero-debiasing assumes a
            # 0-initialized variable (moving_variance starts at 1) and would
            # create extra variables inside this tf.cond branch.
            with tf.control_dependencies([r, d]):
                update_mean = assign_moving_average(self.moving_mean, batch_mean,
                                                    self.renorm_momentum,
                                                    zero_debias=False)
                update_variance = assign_moving_average(self.moving_variance,
                                                        batch_variance,
                                                        self.renorm_momentum,
                                                        zero_debias=False)

            with tf.control_dependencies([update_mean, update_variance]):
                return tf.nn.batch_normalization(inputs, batch_mean, batch_variance,
                                                 offset, scale, self.epsilon)

        def _batch_norm_inference():
            return tf.nn.batch_normalization(
                inputs,
                mean=self.moving_mean,
                variance=self.moving_variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=self.epsilon)

        train = tf.convert_to_tensor(train)
        return tf.cond(train, _batch_norm_training, _batch_norm_inference)

# Reposted from blog.csdn.net/m0_37561765/article/details/80176908