[PyTorch] Focal loss (+ OHEM) for image classification

本人kaggle分享链接:https://www.kaggle.com/c/bengaliai-cv19/discussion/128665

import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE(review): Variable is deprecated since PyTorch 0.4 — plain tensors carry autograd now.
from torch.autograd import Variable
# Single global device; used below to move the alpha weights onto the GPU when inputs are CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Computes ``-alpha_t * (1 - p_t)^gamma * log(p_t)`` per sample, where
    ``p_t`` is the softmax probability of the true class.

    Args:
        class_num: number of classes C (inputs are expected to be (N, C) logits).
        alpha: optional per-class weights, shape (C,) or (C, 1); defaults to all ones.
        gamma: focusing exponent; gamma=0 reduces to alpha-weighted cross-entropy.
        size_average: if True return the batch mean, else the batch sum.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            # Uniform class weighting; (C, 1) so indexing by target ids yields (N, 1).
            self.alpha = torch.ones(class_num, 1)
        elif torch.is_tensor(alpha):
            self.alpha = alpha
        else:
            self.alpha = torch.tensor(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """inputs: (N, C) raw logits; targets: (N,) int64 class indices. Returns scalar loss."""
        N = inputs.size(0)
        C = inputs.size(1)
        # Explicit dim: implicit-dim softmax is deprecated and ambiguous for 2-D inputs.
        P = F.softmax(inputs, dim=1)
        # One-hot mask of the target class for each row.
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)
        # Keep the class weights on the same device as the inputs (not a global device).
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]               # (N, 1) per-sample class weight
        probs = (P * class_mask).sum(1).view(-1, 1)    # (N, 1) p_t of the true class
        log_p = probs.log()
        # Focal modulation: well-classified samples (p_t -> 1) are down-weighted.
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
# Three focal-loss criteria, one per classification head.  The class counts
# (168 / 11 / 7) presumably correspond to the three Bengali.AI targets
# (grapheme root / vowel diacritic / consonant diacritic) from the linked
# competition — confirm against the model's output heads.
F1 = FocalLoss(168)
F2 = FocalLoss(11)
F3 = FocalLoss(7)
def ohem_loss(rate, cls_pred, cls_target):
    """Online Hard Example Mining over a per-sample focal loss.

    Keeps only the hardest ``rate * batch_size`` examples (highest loss)
    and averages over them.

    Args:
        rate: fraction of the batch to keep, e.g. 0.7; values >= 1 keep all.
        cls_pred: (N, C) raw logits.
        cls_target: (N,) int64 class indices.

    Returns:
        Scalar loss averaged over the kept hard examples.
    """
    batch_size = cls_pred.size(0)
    # Per-sample focal loss (gamma=2, unit alpha) — the same quantity
    # FocalLoss computes, but unreduced.  BUG in the original: it called the
    # module-level FocalLoss instance, which returns the *batch mean*, so the
    # sort/top-k below operated on a 0-dim tensor and mined nothing.
    ce = F.cross_entropy(cls_pred, cls_target, reduction='none')  # (N,) = -log p_t
    pt = torch.exp(-ce)                                           # (N,) p_t of the true class
    ohem_cls_loss = (1.0 - pt) ** 2 * ce                          # focal modulation, gamma=2
    sorted_ohem_loss, idx = torch.sort(ohem_cls_loss, descending=True)
    # Keep at most rate*batch of the hardest samples; at least one so the
    # final division can never be by zero.
    keep_num = max(1, min(sorted_ohem_loss.size(0), int(batch_size * rate)))
    if keep_num < sorted_ohem_loss.size(0):
        keep_idx = idx[:keep_num]
        ohem_cls_loss = ohem_cls_loss[keep_idx]
    cls_loss = ohem_cls_loss.sum() / keep_num
    return cls_loss
发布了342 篇原创文章 · 获赞 794 · 访问量 178万+

猜你喜欢

转载自blog.csdn.net/u014365862/article/details/104216192