class FocalLoss(nn.Module):
    """Focal-loss wrapper around an element-wise loss function, e.g.
    criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).

    Args:
        loss_fcn: element-wise loss; must be nn.BCEWithLogitsLoss().
        gamma: focusing parameter — down-weights easy examples.
        alpha: class-balance weight for the positive class.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        # Force per-element losses so the focal weights can be applied
        # before the (remembered) reduction.
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        raw = self.loss_fcn(pred, true)
        # TF implementation:
        # https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        prob = torch.sigmoid(pred)  # probability from logits
        p_t = true * prob + (1 - true) * (1 - prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        weighted = raw * alpha_factor * modulating_factor
        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted  # 'none'
# DiceLoss
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss, averaged uniformly (or by `weight`) over classes."""

    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        # Stack a binary mask per class along a new channel dimension.
        channels = [(input_tensor == c).unsqueeze(1) for c in range(self.n_classes)]
        return torch.cat(channels, dim=1).float()

    def _dice_loss(self, score, target):
        # Soft Dice for one class: 1 - 2|A∩B| / (|A|² + |B|²), smoothed.
        target = target.float()
        smooth = 1e-5
        overlap = torch.sum(score * target)
        denom = torch.sum(score * score) + torch.sum(target * target)
        return 1 - (2 * overlap + smooth) / (denom + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        total = 0.0
        for c in range(self.n_classes):
            dice = self._dice_loss(inputs[:, c], target[:, c])
            class_wise_dice.append(1.0 - dice.item())
            total += dice * weight[c]
        return total / self.n_classes
# Lovasz-Softmax loss
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from itertools import filterfalse as ifilterfalse
# https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    num = len(gt_sorted)
    gt_total = gt_sorted.sum()
    fg_cum = gt_sorted.float().cumsum(0)
    bg_cum = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (gt_total - fg_cum) / (gt_total + bg_cum)
    if num > 1:  # cover 1-pixel case
        # First-order difference turns cumulative IoU into per-pixel weights.
        jaccard[1:num] = jaccard[1:num] - jaccard[0:-1]
    return jaccard
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        # Treat the whole batch as a single "image".
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        hits = ((label == 1) & (pred == 1)).sum()
        covered = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        # Empty union: fall back to the EMPTY score instead of dividing by zero.
        ious.append(float(hits) / float(covered) if covered else EMPTY)
    # mean across images if per_image
    return 100 * mean(ious)
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        # Treat the whole batch as a single "image".
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        per_class = []
        for cls in range(C):
            if cls == ignore:
                # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                continue
            inter = ((label == cls) & (pred == cls)).sum()
            uni = ((label == cls) | ((pred == cls) & (label != ignore))).sum()
            per_class.append(float(inter) / float(uni) if uni else EMPTY)
        ious.append(per_class)
    # mean across images if per_image
    ious = [mean(vals) for vals in zip(*ious)]
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        per_image_losses = (
            lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
            for log, lab in zip(logits, labels)
        )
        return mean(per_image_losses)
    return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -inf and +inf)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    # Map {0, 1} labels to {-1, +1} signs for the hinge margin.
    signs = 2. * labels.float() - 1.
    errors = 1. - logits * Variable(signs)
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm.data]
    return torch.dot(F.relu(errors_sorted), Variable(lovasz_grad(gt_sorted)))
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores, labels = scores.view(-1), labels.view(-1)
    if ignore is None:
        return scores, labels
    keep = labels != ignore
    return scores[keep], labels[keep]
class StableBCELoss(torch.nn.modules.Module):
    """Numerically stable binary cross-entropy computed directly on logits."""

    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        # log(1 + exp(-|x|)) avoids overflow for large-magnitude logits.
        stable_log = (1 + (-input.abs()).exp()).log()
        elementwise = input.clamp(min=0) - input * target + stable_log
        return elementwise.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
    logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    ignore: void class id
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    return StableBCELoss()(flat_logits, Variable(flat_labels.float()))
# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
            Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    if not per_image:
        return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    image_losses = (
        lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
        for prob, lab in zip(probas, labels)
    )
    return mean(image_losses)
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
    probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)  # number of classes
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        # BUGFIX: was `classes is 'present'` — identity comparison against a
        # string literal is implementation-dependent (SyntaxWarning on
        # CPython >= 3.8); equality is the intended check.
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    Returns ([P, C] probabilities, [P] labels), dropping `ignore` pixels.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer: [B, H, W] -> [B, 1, H, W]
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # BUGFIX: was probas[valid.nonzero().squeeze()], which collapses the index
    # to 0-d when exactly one pixel is valid, dropping the [P, C] row shape.
    # Boolean-mask indexing keeps the row dimension in every case.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    ignore: void class label to exclude from the loss; defaults to 255
            when None, matching the previous hard-coded behaviour.
    """
    # BUGFIX: `ignore` was accepted but never used — ignore_index was
    # hard-coded to 255. Honour the argument while keeping 255 as default.
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, Variable(labels), ignore_index=ignore_index)


# --------------------------- HELPER FUNCTIONS ---------------------------


def isnan(x):
    """True iff x is NaN (NaN is the only value unequal to itself)."""
    return x != x
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    count = 1
    try:
        total = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for count, v in enumerate(values, 2):
        total += v
    # A single element is returned as-is, avoiding a needless division.
    return total if count == 1 else total / count