# Author's Kaggle discussion post: https://www.kaggle.com/c/bengaliai-cv19/discussion/128665
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# Default compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    For every sample the loss is ``-alpha[t] * (1 - p_t) ** gamma * log(p_t)``,
    where ``t`` is the target class and ``p_t`` is the softmax probability the
    logits assign to it.

    Args:
        class_num: Number of classes; sizes the default ``alpha``.
        alpha: Optional per-class weights with ``class_num`` elements (tensor
            of shape ``(class_num,)`` or ``(class_num, 1)``, or any sequence
            convertible to a tensor). Defaults to all ones.
        gamma: Exponent applied to ``1 - p_t``.
        size_average: If true the per-sample losses are averaged, otherwise
            they are summed.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha
        else:
            # Accept plain lists/tuples of weights as well as tensors.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Return the reduced focal loss for a batch.

        Args:
            inputs: Logits of shape ``(N, C)``.
            targets: ``N`` integer class indices (any shape that flattens to N).

        Returns:
            A 0-dim tensor: the mean (``size_average=True``) or the sum of the
            per-sample losses.
        """
        ids = targets.view(-1, 1).long()

        # log_softmax + gather yields log(p_t) directly. Taking softmax first
        # and log afterwards underflows to log(0) = -inf for confident wrong
        # predictions, which turns the whole loss into inf/nan.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)
        probs = log_p.exp()

        # Follow the device of the inputs instead of a module-level default so
        # the loss works on any GPU index as well as on the CPU.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        # Flatten before indexing so that both (C,) and (C, 1) weights give an
        # (N, 1) column; a 1-D result would broadcast against probs to (N, N).
        alpha = self.alpha.reshape(-1)[ids.view(-1)].view(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
# Shared focal-loss criteria for 168-, 11- and 7-class outputs respectively.
# NOTE(review): presumably one criterion per prediction head - confirm with the training code.
F1, F2, F3 = (FocalLoss(num_classes) for num_classes in (168, 11, 7))
def _focal_loss_per_sample(logits, targets, gamma):
    """Return the unreduced focal loss, one value per row of ``logits``."""
    ids = targets.view(-1, 1).long()
    log_p = F.log_softmax(logits, dim=1).gather(1, ids).view(-1)
    return -torch.pow(1 - log_p.exp(), gamma) * log_p


def ohem_loss(rate, cls_pred, cls_target, gamma=2):
    """Online hard example mining over a per-sample focal loss.

    Only the hardest ``int(batch_size * rate)`` samples (those with the largest
    loss) contribute; the result is their mean loss.

    The per-sample term is ``-(1 - p_t) ** gamma * log(p_t)``, i.e. a focal
    loss with unit class weights. It is computed unreduced here because mining
    needs one loss value per sample; an already-averaged scalar cannot be
    ranked.

    Args:
        rate: Fraction of the batch to keep. Values >= 1 keep every sample.
        cls_pred: Logits of shape ``(N, C)``.
        cls_target: ``N`` integer class indices.
        gamma: Focal exponent applied to ``1 - p_t``.

    Returns:
        A 0-dim tensor holding the mean loss of the kept samples.
    """
    batch_size = cls_pred.size(0)
    per_sample = _focal_loss_per_sample(cls_pred, cls_target, gamma)

    total = per_sample.numel()
    if total == 0:
        # Nothing to rank; a zero that is still attached to the graph.
        return per_sample.sum()

    # Always keep at least one sample: a tiny rate would otherwise give
    # keep_num == 0 and a 0 / 0 = nan loss.
    keep_num = max(1, min(total, int(batch_size * rate)))
    hardest, _ = torch.topk(per_sample, keep_num)
    return hardest.sum() / keep_num