Foreword
When training segmentation networks, the choice of loss function matters a great deal. Semantic segmentation data very often suffer from an imbalance between positive and negative samples, or between categories. Choosing an appropriate loss function therefore plays a key role in both model convergence and prediction accuracy.
1. Cross entropy loss
M is the number of categories;
y_ic is an indicator function: it equals 1 if the true category of sample i is c, and 0 otherwise;
p_ic is the predicted probability that sample i belongs to category c, as produced by the network (e.g. through a softmax);
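For reference, here is the standard multi-class cross-entropy in which these symbols appear. It is averaged over N samples (pixels); N is our notation here:

```latex
L_{CE} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{M} y_{ic}\,\log(p_{ic})
```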
Disadvantages:
Cross-entropy loss works in most semantic segmentation scenarios, but it has an obvious drawback. Consider foreground/background segmentation where foreground pixels are far fewer than background pixels. In that case the background terms dominate the loss, the model becomes heavily biased toward the background, and both training and prediction results suffer.
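The following sketch makes the imbalance argument concrete. It uses a hypothetical 100x100 mask, not data from the experiments. It measures what fraction of the summed cross-entropy comes from background pixels when an untrained network outputs equal scores for both classes. `background_loss_share` is an illustrative helper name.

```python
import torch
import torch.nn.functional as F


def background_loss_share(logits, target):
    """Fraction of the summed pixel-wise cross-entropy contributed by background (class 0) pixels."""
    per_pixel = F.cross_entropy(logits, target, reduction="none")  # shape N,H,W
    return (per_pixel[target == 0].sum() / per_pixel.sum()).item()


# Hypothetical 100x100 mask with a 10x10 foreground square, i.e. 1% foreground
target = torch.zeros(1, 100, 100, dtype=torch.long)
target[0, :10, :10] = 1

# Untrained network: equal logits for both classes, so every pixel costs log(2)
logits = torch.zeros(1, 2, 100, 100)

share = background_loss_share(logits, target)
print(f"background share of the loss: {share:.2f}")
```

Because every pixel costs the same log 2 here, the background share simply equals the background pixel ratio (0.99). At this point each pixel's gradient also has the same magnitude, so 99% of the training signal pushes toward "background".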
Binary cross-entropy (BCE) loss faces the same problem. It computes a separate binary classification loss for each of the N categories (elements) and averages them.
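In standard form, let y_i ∈ {0, 1} be the label of the i-th term and p_i its sigmoid output (our notation). The binary cross-entropy is then:

```latex
L_{BCE} = -\frac{1}{N}\sum_{i=1}^{N}\left[\,y_i\log(p_i) + (1-y_i)\log(1-p_i)\,\right]
```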
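# Binary cross-entropy: the input must be passed through a sigmoid first
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(2, 1, 4, 4)                   # N,1,H,W raw network output
mask = torch.randint(0, 2, (2, 1, 4, 4)).float()   # binary ground truth
loss_bce = nn.BCELoss()(torch.sigmoid(logits), mask)

# Multi-class cross-entropy: takes raw logits, no Softmax layer is needed before this loss
logits_mc = torch.randn(2, 3, 4, 4)                # N,C,H,W raw network output
labels = torch.randint(0, 3, (2, 4, 4))            # N,H,W class indices
loss_ce = nn.CrossEntropyLoss()(logits_mc, labels)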
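This sketch checks the definitions of y_ic and p_ic above and the two usage rules (sigmoid before BCELoss, no Softmax before CrossEntropyLoss). It uses random hypothetical tensors, computes the cross-entropy by hand, and compares the result with the built-in PyTorch losses:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Multi-class case: N=2 images, M=3 classes, 4x4 pixels (hypothetical sizes)
logits = torch.randn(2, 3, 4, 4)
labels = torch.randint(0, 3, (2, 4, 4))

# By hand: -(1/N) * sum_i sum_c y_ic * log(p_ic), averaged over all pixels
p = F.softmax(logits, dim=1)
y = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()  # N,H,W,C -> N,C,H,W
ce_manual = -(y * p.log()).sum(dim=1).mean()
ce_builtin = nn.CrossEntropyLoss()(logits, labels)  # applies log-softmax internally

# Binary case: sigmoid + BCELoss matches the fused BCEWithLogitsLoss
bin_logits = torch.randn(2, 1, 4, 4)
mask = torch.randint(0, 2, (2, 1, 4, 4)).float()
bce_two_step = nn.BCELoss()(torch.sigmoid(bin_logits), mask)
bce_fused = nn.BCEWithLogitsLoss()(bin_logits, mask)

print(torch.allclose(ce_manual, ce_builtin, atol=1e-5))
print(torch.allclose(bce_two_step, bce_fused, atol=1e-5))
```

`nn.BCEWithLogitsLoss` folds the sigmoid into the loss. It is generally the more numerically stable choice of the two binary variants.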
2. Focal loss
Kaiming He's team introduced Focal Loss in the RetinaNet paper to address the imbalance between the numbers of hard and easy samples. Let's review it.
It reweights the loss by both sample frequency and prediction confidence. Samples from abundant classes and samples already predicted with high confidence receive lower loss weights.
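In the notation of the RetinaNet paper, p_t is the predicted probability of the true class, α_t is a class-balancing weight, and γ ≥ 0 is the focusing parameter:

```latex
FL(p_t) = -\alpha_t\,(1-p_t)^{\gamma}\,\log(p_t),\qquad
p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{otherwise} \end{cases}
```

The factor (1 − p_t)^γ shrinks the loss of confident, easy samples. With γ = 0 and α_t = 1 it reduces to ordinary cross-entropy.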
import numpy as np
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    """
    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss with smooth label cross entropy supported which is proposed in
    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
        Focal_Loss= -1*alpha*(1-pt)**gamma*log(pt)
    :param apply_nonlin: (callable) applied to the network output first; it must turn logits into
        probabilities, e.g. nn.Softmax(dim=1)
    :param alpha: (None, float, list or np.ndarray) class weighting factor
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
        focus on hard misclassified example
    :param smooth: (float,double) smooth value when cross entropy
    :param balance_index: (int) balance class index, should be specific when alpha is float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)  # from here on, logit must hold probabilities
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        # N,1,d1,... or N,d1,... or N -> (N*m, 1)
        target = target.reshape(-1, 1)

        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        # one-hot encoding of the target labels
        one_hot_key = torch.zeros(target.size(0), num_class)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            # label smoothing: keep the one-hot values inside [smooth/(C-1), 1 - smooth]
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth  # probability of the true class
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx.to(alpha.device)]  # per-element weight alpha_t
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
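The class above is taken from the Loss_ToolBox repository. The sketch below is a compact functional version of the same idea, working directly on raw logits. It is our own simplification, not that repository's API, and `focal_loss` is an illustrative name. With gamma = 0 and no alpha it should reduce exactly to cross-entropy, which the check confirms:

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0, alpha=None):
    """Minimal multi-class focal loss: mean of -alpha_t * (1 - p_t)^gamma * log(p_t).

    logits: N,C,... raw scores; target: N,... class indices;
    alpha: optional per-class weight tensor of shape (C,).
    """
    log_p = F.log_softmax(logits, dim=1)
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log-probability of the true class
    pt = log_pt.exp()
    loss = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        loss = alpha.to(logits.device)[target] * loss  # alpha_t: weight of each pixel's true class
    return loss.mean()


torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 8, 8))

ce = F.cross_entropy(logits, target)
fl0 = focal_loss(logits, target, gamma=0.0)  # no focusing: plain cross-entropy
fl2 = focal_loss(logits, target, gamma=2.0)  # easy (confident) pixels are down-weighted
print(torch.allclose(ce, fl0, atol=1e-6), fl2.item() < ce.item())
```

Working in log space with `log_softmax` avoids taking the log of a probability that has underflowed to zero.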
3. Dice loss function
Dice is a set-similarity measure. It is usually used to compute the similarity between two samples and belongs to metric learning. X is the ground-truth mask and Y is the predicted mask. We want the intersection of X and Y to be as large as possible relative to their sizes. A loss, however, must decrease during training, so the ratio enters the loss with a negative sign; in practice the loss is 1 minus the ratio.
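Written out with X and Y as above:

```latex
\mathrm{Dice}(X,Y) = \frac{2\,|X\cap Y|}{|X| + |Y|},\qquad
L_{Dice} = 1 - \frac{2\sum_i x_i y_i + s}{\sum_i x_i + \sum_i y_i + s}
```

For soft predictions, the set sizes are replaced by sums over pixels (x_i, y_i are our notation for pixel values). A small smoothing constant s is usually added to avoid division by zero on empty masks.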
Dice loss alleviates the negative effect of foreground/background (area) imbalance within a sample. Such imbalance means that most of the image contains no target and only a small region does. Dice loss concentrates training on mining the foreground region, which tends to keep false negatives (FN) low, but it suffers from loss saturation. CE loss, by contrast, weights the loss of every pixel equally. For this reason, Dice loss alone often does not give good results and is usually combined with another loss, e.g. Dice Loss + CE Loss or Dice Loss + Focal Loss.
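Here is a small sketch of the imbalance argument on a hypothetical 1-D mask. The foreground prediction is held fixed while perfectly classified background pixels are added. `soft_dice_loss` and `losses_with_background` are illustrative names:

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(prob, target, smooth=1.0):
    """1 - (2|X∩Y| + smooth) / (|X| + |Y| + smooth) on soft (probability) masks."""
    inter = (prob * target).sum()
    return 1 - (2 * inter + smooth) / (prob.sum() + target.sum() + smooth)


def losses_with_background(n_bg):
    # Hypothetical mask: 10 foreground pixels predicted at p = 0.6,
    # followed by n_bg background pixels predicted perfectly (p = 0).
    prob = torch.cat([torch.full((10,), 0.6), torch.zeros(n_bg)])
    target = torch.cat([torch.ones(10), torch.zeros(n_bg)])
    return F.binary_cross_entropy(prob, target).item(), soft_dice_loss(prob, target).item()


for n_bg in (10, 1000):
    bce, dice = losses_with_background(n_bg)
    print(f"{n_bg:5d} background pixels: BCE={bce:.4f}  Dice loss={dice:.4f}")
```

The per-pixel-averaged BCE shrinks by roughly 50x once the easy background is added, even though the foreground prediction is unchanged. The Dice loss stays at about 0.235, because it depends only on the overlap terms.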
Here is the link to the original text: https://blog.csdn.net/Mike_honor/article/details/125871091
import torch
import torch.nn.functional as F


def dice_loss(prediction, target):
    """Calculating the dice loss
    Args:
        prediction = predicted image (probabilities in [0, 1])
        target = Targeted image (binary mask, same shape)
    Output:
        dice_loss"""
    smooth = 1.0

    i_flat = prediction.reshape(-1)
    t_flat = target.reshape(-1)

    intersection = (i_flat * t_flat).sum()

    return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))


def calc_loss(prediction, target, bce_weight=0.5):
    """Calculating the combined BCE + Dice loss
    Args:
        prediction = predicted image (raw logits)
        target = Targeted image (binary float mask, same shape)
        bce_weight = 0.5 (default)
    Output:
        loss : weighted sum of BCE loss and Dice loss"""
    bce = F.binary_cross_entropy_with_logits(prediction, target)

    prediction = torch.sigmoid(prediction)
    dice = dice_loss(prediction, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss
4. IoU loss
IoU loss is similar to Dice loss: it is also a set-similarity measure from metric learning. Give it a try in your experiments; for small-target segmentation it can improve convergence remarkably!
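For a soft prediction P and target T (our notation), the IoU (Jaccard) loss with smoothing constant s reads:

```latex
\mathrm{IoU}(P,T) = \frac{|P\cap T|}{|P\cup T|} = \frac{|P\cap T|}{|P| + |T| - |P\cap T|},\qquad
L_{IoU} = 1 - \frac{\sum_i p_i t_i + s}{\sum_i p_i + \sum_i t_i - \sum_i p_i t_i + s}
```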
import torch


def SoftIoULoss(pred, target):
    # pred: raw logits; target: binary mask of the same shape
    pred = torch.sigmoid(pred)
    smooth = 1

    intersection = pred * target
    # IoU over the whole batch: (|P∩T| + s) / (|P| + |T| - |P∩T| + s)
    loss = (intersection.sum() + smooth) / (pred.sum() + target.sum() - intersection.sum() + smooth)

    # Per-sample alternative for N,C,H,W tensors:
    # loss = (intersection.sum(axis=(1, 2, 3)) + smooth) / \
    #        (pred.sum(axis=(1, 2, 3)) + target.sum(axis=(1, 2, 3))
    #         - intersection.sum(axis=(1, 2, 3)) + smooth)

    loss = 1 - loss.mean()
    # loss = (1 - loss).mean()
    return loss
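Since IoU loss is described as similar to Dice loss, a quick sketch with random hypothetical tensors shows how the two are related (smoothing omitted). Without the smooth term, IoU = Dice / (2 − Dice) holds for soft masks as well:

```python
import torch


def soft_dice(pred, target):
    inter = (pred * target).sum()
    return 2 * inter / (pred.sum() + target.sum())


def soft_iou(pred, target):
    inter = (pred * target).sum()
    return inter / (pred.sum() + target.sum() - inter)


torch.manual_seed(0)
pred = torch.rand(1, 1, 16, 16)                    # hypothetical soft prediction in [0, 1]
target = (torch.rand(1, 1, 16, 16) > 0.7).float()  # hypothetical binary mask

d, j = soft_dice(pred, target), soft_iou(pred, target)
# IoU = Dice / (2 - Dice), hence IoU <= Dice and (1 - IoU) >= (1 - Dice)
print(torch.allclose(j, d / (2 - d)), (1 - j).item() >= (1 - d).item())
```

IoU is a monotone function of Dice, so both rank predictions the same way. For the same prediction, the IoU loss is never smaller than the Dice loss, and the two produce different gradient scales. The smooth term in `SoftIoULoss` breaks the identity slightly, most noticeably when the masks are small.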
5. Tversky loss
Segmentation tasks also differ in what they emphasize. Medical segmentation, for example, cares more about recall (high sensitivity): the real mask should be covered as completely as possible, while some over-prediction in the predicted mask is tolerated. Let B be the real mask and A the predicted mask. Then |A−B| counts the false positives and |B−A| the false negatives, and alpha and beta control the trade-off between them. If recall matters more, the weight on |B−A| is increased.
In other words, alpha and beta shift the balance between recall and precision. To get a higher recall on the target, choose a larger beta.
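With A the predicted mask and B the real mask, the Tversky index is:

```latex
T(A,B) = \frac{|A\cap B|}{|A\cap B| + \alpha\,|A\setminus B| + \beta\,|B\setminus A|}
```

Setting α = β = 0.5 recovers the Dice coefficient, and α = β = 1 recovers IoU (Jaccard). Choosing β > α, as in the code below, penalizes false negatives more heavily.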
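import torch
import torch.nn as nn


def get_tp_fp_fn(net_output, gt, axes, mask=None, square=False):
    # Not shown in the original post (the class uses a helper of this name from nnU-Net).
    # Minimal stand-in: soft TP, FP, FN summed over `axes`; net_output holds N,C,... probabilities,
    # gt is one-hot N,C,... or a label map N,1,... / N,...
    if net_output.shape != gt.shape:
        gt = gt.long()
        if gt.dim() == net_output.dim() - 1:
            gt = gt.unsqueeze(1)
        y_onehot = torch.zeros_like(net_output).scatter_(1, gt, 1)
    else:
        y_onehot = gt.float()
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    if mask is not None:  # mask: N,1,..., 1 where the pixel counts towards the loss
        tp, fp, fn = tp * mask, fp * mask, fn * mask
    if square:
        tp, fp, fn = tp ** 2, fp ** 2, fn ** 2
    axes = tuple(axes)
    return tp.sum(dim=axes), fp.sum(dim=axes), fn.sum(dim=axes)


class TverskyLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        paper: https://arxiv.org/pdf/1706.05721.pdf
        apply_nonlin should map logits to probabilities, e.g. nn.Softmax(dim=1)
        """
        super(TverskyLoss, self).__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.alpha = 0.3  # weight of false positives
        self.beta = 0.7   # weight of false negatives (beta > alpha favors recall)

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))  # one score per class
        else:
            axes = list(range(2, len(shp_x)))  # one score per sample and class
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)
        if not self.do_bg:  # drop the background class (channel 0)
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()
        return -tversky  # minimizing this maximizes the Tversky index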
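The next sketch is a minimal functional version, separate from the class above; `tversky_index` is an illustrative helper name. It uses the same alpha = 0.3 and beta = 0.7 as the class, on a hypothetical 1-D mask. It checks the recall-oriented behavior described above: missing five foreground pixels should cost more than adding five spurious ones.

```python
import torch


def tversky_index(prob, target, alpha=0.3, beta=0.7, smooth=1.0):
    """(TP + s) / (TP + alpha*FP + beta*FN + s) on soft binary masks."""
    tp = (prob * target).sum()
    fp = (prob * (1 - target)).sum()
    fn = ((1 - prob) * target).sum()
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)


# Hypothetical 1-D ground truth: 20 foreground pixels out of 100
target = torch.zeros(100)
target[:20] = 1

under = target.clone()
under[:5] = 0    # misses 5 foreground pixels -> 5 FN
over = target.clone()
over[20:25] = 1  # marks 5 background pixels as foreground -> 5 FP

losses = {name: (1 - tversky_index(pred, target)).item()
          for name, pred in (("5 FN", under), ("5 FP", over))}
print(losses)  # the 5-FN prediction has the larger loss
```

This sketch uses 1 − index as the loss, whereas the class returns the negative index. The two differ only by a constant, so they produce the same gradients.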
Summary
After a series of experiments, the latter four loss functions (Focal, Dice, IoU and Tversky) were found to be better suited to training networks for small-target segmentation. Every task is different, though, so if you have enough time, try them one by one.