【tensorflow】Setting up BatchNormalization when fine-tuning with the slim module

Copyright notice: this is an original post by the author and may not be reposted without permission. https://blog.csdn.net/shwan_ma/article/details/83502333

TensorFlow's BatchNorm is probably one of the biggest pitfalls in TensorFlow. The most common problem people hit is during fine-tuning: you load a pretrained model, training looks fine, but at test time the model falls flat on its face.

The reason is that at test time batch normalization needs an estimate of the mean and variance over the whole training set, and in the implementation BN obtains that estimate by keeping moving averages of the per-batch mean and variance, updated after every training batch. These updates do not happen automatically: the update ops for moving_mean and moving_variance are placed in the tf.GraphKeys.UPDATE_OPS collection, and you have to run that collection yourself during training. If you forget, the moving statistics stay at their initial (or pretrained) values; training still looks good because batch statistics are used, but inference breaks.
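To make the mechanism concrete, here is a minimal sketch of how slim's batch_norm behaves in the two modes (the layer sizes, placeholder names, and the is_training flag are illustrative assumptions, not taken from the Inception code below):

import tensorflow as tf
slim = tf.contrib.slim

# Illustrative placeholders; shapes and names are assumptions for this sketch.
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
is_training = tf.placeholder(tf.bool, name='is_training')

# With is_training=True, batch_norm normalizes with the current batch statistics
# and registers ops in tf.GraphKeys.UPDATE_OPS that move moving_mean and
# moving_variance toward them; with is_training=False it normalizes with the
# accumulated moving averages instead.
net = slim.conv2d(inputs, 32, [3, 3],
                  normalizer_fn=slim.batch_norm,
                  normalizer_params={'is_training': is_training,
                                     'decay': 0.9997,
                                     'epsilon': 0.001})

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # the BN update ops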

Take the arg_scope of Inception v3 as an example (excerpted from inception_v3.py in tensorflow.contrib.slim; `layers`, `layers_lib`, `arg_scope`, `regularizers`, `initializers`, `nn_ops`, and `ops` are the contrib/framework modules imported at the top of that file):

def inception_v3_arg_scope(weight_decay=0.00004,
                           batch_norm_var_collection='moving_vars',
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=0.001,
                           updates_collections=ops.GraphKeys.UPDATE_OPS,
                           use_fused_batchnorm=True):
  """Defines the default InceptionV3 arg scope.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.
    batch_norm_decay: Decay for batch norm moving average
    batch_norm_epsilon: Small float added to variance to avoid division by zero
    updates_collections: Collections for the update ops of the layer
    use_fused_batchnorm: Enable fused batchnorm.

  Returns:
    An `arg_scope` to use for the inception v3 model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      # collection containing update_ops.
      'updates_collections': updates_collections,
      # Use fused batch norm if possible.
      'fused': use_fused_batchnorm,
      # collection containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': [batch_norm_var_collection],
          'moving_variance': [batch_norm_var_collection],
      }
  }

  # Set weight_decay for weights in Conv and FC layers.
  with arg_scope(
      [layers.conv2d, layers_lib.fully_connected],
      weights_regularizer=regularizers.l2_regularizer(weight_decay)):
    with arg_scope(
        [layers.conv2d],
        weights_initializer=initializers.variance_scaling_initializer(),
        activation_fn=nn_ops.relu,
        normalizer_fn=layers_lib.batch_norm,
        normalizer_params=batch_norm_params) as sc:
      return sc

As the arg_scope shows, `updates_collections` is set to ops.GraphKeys.UPDATE_OPS, so the ops that refresh moving_mean and moving_variance land in that collection; the `variables_collections` entry only decides which collection the moving_mean/moving_variance variables themselves are stored in ('moving_vars' here). Those update ops still have to be run explicitly as part of the training step.
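A quick way to check, after building the network under this arg_scope, that both collections are populated (a hedged sketch; the exact counts depend on the network):

# The moving_mean / moving_variance variables themselves:
moving_vars = tf.get_collection('moving_vars')
# The ops that update them, registered via updates_collections:
bn_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print('moving vars:', len(moving_vars), 'update ops:', len(bn_update_ops))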

Code example:

opt = tf.train.AdamOptimizer(learning_rate=lr_v)
# Collect the moving_mean / moving_variance update ops registered by batch_norm.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Make the training step depend on them, so every minimize() run also updates BN.
with tf.control_dependencies([tf.group(*update_ops)]):
    train_op = opt.minimize(loss)

The code above makes every loss-minimization step also run the BN update ops, so the moving statistics are refreshed during fine-tuning.
With this in place, the train-well/test-badly problem is solved.
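For completeness, a hedged sketch of the surrounding fine-tuning loop (names such as inputs, labels, is_training, train_batch, train_labels, test_batch, and accuracy are illustrative placeholders, not part of the original post):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... restore the pretrained weights here, e.g. with a tf.train.Saver ...
    for step in range(10000):
        # Training: batch statistics are used, and because train_op depends on
        # the UPDATE_OPS, moving_mean / moving_variance get refreshed.
        sess.run(train_op, feed_dict={inputs: train_batch,
                                      labels: train_labels,
                                      is_training: True})
    # Evaluation: the accumulated moving averages are used instead of per-batch
    # statistics, so is_training must be False here.
    sess.run(accuracy, feed_dict={inputs: test_batch,
                                  is_training: False})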

Reference: https://blog.csdn.net/qq_25737169/article/details/79616671
