版权声明：本文为 CSDN 博主 Rosefun96 的原创文章，原文链接：https://blog.csdn.net/rosefun96/article/details/88241916
1 理论
OHEM（Online Hard Example Mining，在线难例挖掘）的思想是：在每次迭代中，先计算所有候选样本的 loss，然后只选取 loss 最大的前 k 个样本（即“难例”）计算损失并反向传播，从而让训练集中在模型最难区分的样本上。
2 实现
def rpn_class_loss_graph(config, rpn_match, rpn_class_logits):
    """RPN anchor classifier loss with online hard example mining (OHEM).

    Computes the cross-entropy of every non-neutral anchor, keeps only the
    ``config.RPN_TRAIN_ANCHORS_PER_IMAGE`` anchors with the largest loss and
    returns their mean. Modified by YY to implement OHEM.

    Args:
        config: Object exposing ``RPN_TRAIN_ANCHORS_PER_IMAGE``, the number of
            hardest anchors that contribute to the loss.
        rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
            -1=negative, 0=neutral anchor.
        rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for FG/BG.

    Returns:
        Scalar tensor: mean cross-entropy over the selected hard anchors, or
        0.0 when there are no non-neutral anchors.
    """
    # Squeeze last dim to simplify.
    rpn_match = tf.squeeze(rpn_match, -1)
    # Convert the -1/+1 match to 0/1 class targets.
    anchor_class = K.cast(K.equal(rpn_match, 1), tf.int32)
    # Positive and negative anchors contribute to the loss; neutral anchors
    # (match value = 0) don't.
    indices = tf.where(K.not_equal(rpn_match, 0))
    # Pick rows that contribute to the loss and filter out the rest.
    rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
    anchor_class = tf.gather_nd(anchor_class, indices)
    # Per-anchor cross-entropy loss.
    ce_loss = K.sparse_categorical_crossentropy(target=anchor_class,
                                                output=rpn_class_logits,
                                                from_logits=True)
    # NOTE(review): gather_nd above flattens the batch axis, so the top-k is
    # taken over all images in the batch, not per image — confirm intended.
    # tf.nn.top_k requires k <= number of elements, so clamp k to the number
    # of non-neutral anchors actually present.
    n_selected = tf.minimum(
        tf.cast(config.RPN_TRAIN_ANCHORS_PER_IMAGE, tf.int32),
        tf.size(ce_loss))
    # Keep exactly the k hardest anchors (a >= threshold mask would also keep
    # every anchor tied with the k-th loss).
    hard_losses, _ = tf.nn.top_k(ce_loss, k=n_selected)
    # Avoid NaN from averaging an empty tensor when no anchor contributes.
    loss = K.switch(tf.size(ce_loss) > 0,
                    K.mean(hard_losses),
                    tf.constant(0.0))
    return loss
其他实现示例（PyTorch 版本）
class topk_crossEntrophy(nn.Module):
    """Cross-entropy loss averaged over the hardest fraction of a batch (OHEM).

    Per-sample negative log-likelihood losses are computed for the whole batch
    at once, and only the ``top_k`` fraction of samples with the largest loss
    is averaged.

    Args:
        top_k: Fraction of samples (0 < top_k <= 1) whose loss is kept.
            ``1.0`` averages over every sample. For a non-empty batch at least
            one sample is always kept.
    """

    def __init__(self, top_k=0.7):
        super().__init__()
        # reduction="none" yields one loss per sample so they can be ranked.
        self.loss = nn.NLLLoss(reduction="none")
        self.top_k = top_k
        # Log-probabilities over the class axis of an (N, C) input; the
        # attribute name is kept for backward compatibility.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target):
        """Return the mean loss over the hardest samples.

        Args:
            input: Raw class scores of shape (N, C).
            target: Integer class indices of shape (N,).

        Returns:
            A 0-d tensor: the mean of the ``k`` largest per-sample losses, where
            ``k = min(N, max(1, int(top_k * N)))``; a zero tensor if N == 0.
        """
        per_sample = self.loss(self.softmax(input), target)
        n = per_sample.size(0)
        if n == 0:
            # sum() of an empty tensor is 0 (mean() would be NaN) and keeps
            # the autograd graph connected.
            return per_sample.sum()
        # Clamp so a small batch never selects zero samples (NaN mean).
        k = min(n, max(1, int(self.top_k * n)))
        hard_losses, _ = torch.topk(per_sample, k)
        return hard_losses.mean()
参考：
- GitHub；
- 博客：OHEM 详解；
- PyTorch 社区。