坑1:BatchNormalization
1号坑陨石坑
-
tf.nn.batch_normalization()
,tf.layers.batch_normalization
和tensorflow.contrib.layers.batch_norm()
,这三个batch normal函数的封装程度逐渐递增。这三个函数会自动将update_ops
添加到tf.GraphKeys.UPDATE_OPS
这个collection
中。 -
tf.keras.layers.BatchNormalization
不会自动将update_ops
添加到tf.GraphKeys.UPDATE_OPS
这个collection
中。所以在 TensorFlow 训练session
中使用tf.keras.layers.BatchNormalization
时,需要手动将keras.BatchNormalization
层的updates
添加到tf.GraphKeys.UPDATE_OPS
这个collection
中。
x = tf.placeholder("float",[None,32,32,3])
bn1 = tf.keras.layers.BatchNormalization()
y = bn1(x, training=True) # 调用后updates属性才会有内容。
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates)
- Batch Normalization 中需要计算移动平均值,所以 BN 中有一些
update_ops
,在训练中需要通过tf.control_dependencies()
来添加对update_ops
的调用:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss, global_step=self.global_step)
参考
[1] https://blog.csdn.net/u014061630/article/details/85104491
2号坑小土坑
使用Batch Normalization的卷积神经网络,当在验证阶段,将is_training
设置为False
之后,loss
会爆炸式增长。
1.测试阶段代码:
_, loss, acc = self.sess.run([self.model.train_op, self.model.loss, self.model.acc],
feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: True})
1.验证阶段代码:
loss, acc = self.sess.run([self.model.loss, self.model.acc],
feed_dict={self.model.x: x, self.model.y: y, self.model.is_training: False})
在解决loss爆炸和测试阶段与训练阶段loss和accuracy差异巨大的时候,参考1
效果明显。
tf.contrib.layers.batch_normal
Decay for the moving average. Reasonable values for decay are close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc. Lower decay value (recommend trying decay=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability.
参考
[1] https://stackoverflow.com/questions/47953242/tensorflow-batch-normalization-tf-contrib-layers-batch-norm
[2] https://arxiv.org/pdf/1711.00489.pdf
坑2:TensorFlow二分类深度网络训练loss=0.9
现象:在训练一个二分类网络的时候,loss一直不变,维持在0.693这个数。
有一种情况,就是你最后一层输出你不小心加了relu激活。这样可能导致输出值是0,如果你用sigmoid那么求出的概率是0.5,-log0.5等于0.69。而且,由于relu输出值为0的时候,relu的导数为0,梯度没办法回传,因此就一直固定在0.69。
参考:
[1] https://www.jianshu.com/p/45c2180cab17
[2] https://www.zhihu.com/question/275774218/answer/385992804