Hard example mining: online hard example mining (OHEM)

Copyright notice: this is an original article by the CSDN blogger Rosefun96. https://blog.csdn.net/rosefun96/article/details/88241916

1 Theory

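In each training step, OHEM computes the loss of every sample. It then keeps only the few samples with the largest loss (the "hard" examples) and backpropagates through those alone. Training therefore concentrates on the examples the model currently gets most wrong.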
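The selection step can be illustrated with a minimal, framework-free sketch in plain Python. The function names `ohem_loss` and `cross_entropy` and the `keep` parameter are hypothetical and serve only as illustration. The sketch computes a per-sample softmax cross-entropy, keeps the k largest losses, and averages only those. `keep` may be an integer count, like the anchors-per-image setting in the TensorFlow code, or a fraction of the batch, like `top_k=0.7` in the PyTorch example.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample: logsumexp(logits) - logits[label]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]


def ohem_loss(batch_logits, labels, keep):
    """Average loss over the hardest samples only (hypothetical helper).

    keep: an int (number of samples) or a float in (0, 1] (fraction of the batch).
    Returns (mean hard loss, list of hard losses sorted from largest to smallest).
    """
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    k = keep if isinstance(keep, int) else int(keep * len(losses))
    k = max(1, min(k, len(losses)))  # at least one sample, at most the whole batch
    hard = sorted(losses, reverse=True)[:k]
    return sum(hard) / k, hard


logits = [[2.0, 0.0], [0.0, 2.0], [0.5, 0.4], [3.0, -3.0]]
labels = [0, 0, 1, 0]  # the 2nd and 3rd samples are misclassified
loss, hard = ohem_loss(logits, labels, keep=2)
print(len(hard))  # 2
```

Here the two misclassified samples have the largest losses. Only they enter the average, and the two easy samples are ignored.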

2 Implementation

import tensorflow as tf
import keras.backend as K  # Keras 2.x backend API


def rpn_class_loss_graph(config, rpn_match, rpn_class_logits):
    """RPN anchor classifier loss.
    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for FG/BG.

    modified by YY: to implement OHEM
    """
    # Squeeze last dim to simplify
    rpn_match = tf.squeeze(rpn_match, -1)
    # Get anchor classes. Convert the -1/+1 match to 0/1 values.
    anchor_class = K.cast(K.equal(rpn_match, 1), tf.int32)
    # Positive and negative anchors contribute to the loss,
    # but neutral anchors (match value = 0) don't.
    indices = tf.where(K.not_equal(rpn_match, 0))
    # Pick rows that contribute to the loss and filter out the rest.
    # After gather_nd, the anchors of all images in the batch are flattened together.
    rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
    anchor_class = tf.gather_nd(anchor_class, indices)
    # Per-anchor cross-entropy loss, shape [num_non_neutral_anchors]
    ce_loss = K.sparse_categorical_crossentropy(target=anchor_class,
                                                output=rpn_class_logits,
                                                from_logits=True)
    # OHEM: keep only the n_selected anchors with the largest loss.
    # tf.nn.top_k fails if k exceeds the number of anchors, so clamp k.
    n_selected = tf.minimum(tf.cast(config.RPN_TRAIN_ANCHORS_PER_IMAGE, tf.int32),
                            tf.size(ce_loss))
    hard_loss, _ = tf.nn.top_k(ce_loss, k=n_selected)
    # Average over the hard anchors only; maximum() avoids 0/0
    # when there are no positive or negative anchors.
    loss = K.sum(hard_loss) / K.maximum(K.cast(n_selected, 'float32'), 1.0)
    return loss
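The TensorFlow loss above selects hard anchors by rank. There are two common ways to pick the hard set from a vector of per-sample losses: mask every loss that is at least the k-th largest value, or take exactly the k largest values. The two differ when losses tie at the threshold. Both also need k clamped when fewer than k samples exist, because `tf.nn.top_k` requires k to be no larger than the size of its input. The following plain-Python sketch shows the difference. It uses no TensorFlow, and the helper names are hypothetical.

```python
def select_by_threshold(losses, k):
    """Keep every loss >= the k-th largest value (may keep more than k on ties)."""
    k = min(k, len(losses))  # clamp k, as top_k cannot take more than len(losses)
    if k == 0:
        return []
    thresh = sorted(losses, reverse=True)[k - 1]
    return [x for x in losses if x >= thresh]


def select_top_k(losses, k):
    """Keep exactly min(k, len(losses)) of the largest losses."""
    k = min(k, len(losses))
    return sorted(losses, reverse=True)[:k]


def mean_or_zero(xs):
    """Average of the selected losses; 0.0 if nothing was selected."""
    return sum(xs) / len(xs) if xs else 0.0


losses = [0.9, 0.5, 0.5, 0.5, 0.1]
print(select_by_threshold(losses, 2))  # [0.9, 0.5, 0.5, 0.5]
print(select_top_k(losses, 2))         # [0.9, 0.5]
print(mean_or_zero(select_top_k(losses, 10)))
```

With three losses tied at 0.5, the threshold mask keeps four samples instead of two. Averaging the `top_k` values directly always keeps exactly k.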

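Another example: a PyTorch loss that averages the cross-entropy over only the `top_k` fraction of samples with the largest loss.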

import torch
import torch.nn as nn
import torch.nn.functional as F


class topk_crossEntrophy(nn.Module):
    """Cross-entropy averaged over only the hardest top_k fraction of the batch."""

    def __init__(self, top_k=0.7):
        super(topk_crossEntrophy, self).__init__()
        self.top_k = top_k  # fraction of samples to keep, in (0, 1]

    def forward(self, input, target):
        # input: [N, C] raw logits; target: [N] class indices.
        # Per-sample cross-entropy (log-softmax + NLL) without reduction,
        # computed in one vectorized call instead of a Python loop.
        loss = F.cross_entropy(input, target, reduction='none')
        if self.top_k == 1:
            # Keep every sample: plain mean cross-entropy.
            return loss.mean()
        # Keep at least one sample so the mean is never taken over an empty tensor.
        k = max(1, int(self.top_k * loss.size(0)))
        valid_loss, _ = torch.topk(loss, k)
        return valid_loss.mean()



Reposted from blog.csdn.net/rosefun96/article/details/88241916