A Summary of Segmentation Network Loss Functions: Cross-Entropy, Focal Loss, Dice Loss, IoU Loss, Tversky Loss


Foreword

In the actual training of segmentation networks, the choice of loss function is particularly important. Semantic segmentation very often suffers from an imbalance between positive and negative samples, or between categories, so choosing an appropriate loss function plays a vital role in getting the model to converge and predict accurately.


1. Cross-entropy loss

CE = -(1/N) * Σ_i Σ_{c=1}^{M} y_ic * log(p_ic)
M is the number of categories;
y_ic is an indicator function: it equals 1 if observation i belongs to category c and 0 otherwise;
p_ic is the predicted probability that observation i belongs to category c, which the model has to produce beforehand (e.g. via a softmax over the logits).

Disadvantages:
Cross-entropy loss works in most semantic segmentation scenarios, but it has an obvious drawback. When segmenting only foreground and background, and the number of foreground pixels is much smaller than the number of background pixels, the background terms dominate the loss. The model then becomes heavily biased towards the background, and both training and prediction quality suffer.

BCE loss faces the same problem; its form is as follows.
BCE = -(1/N) * Σ_{i=1}^{N} [ y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ]
A binary classification loss is computed independently for each of the N samples (pixels).

import torch
import torch.nn as nn
import torch.nn.functional as F

# Binary cross-entropy: the input must be passed through sigmoid first
bce = nn.BCELoss()
loss = bce(torch.sigmoid(input), target)

# Multi-class cross-entropy: no Softmax layer is needed before this loss
ce = nn.CrossEntropyLoss()
loss = ce(input, target)
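
As a side note, one common way to soften the background dominance described above is to give the rare foreground class a larger weight through the weight argument of nn.CrossEntropyLoss. A minimal sketch, with purely illustrative shapes and weights (not from the original post):

# Hypothetical 2-class setup: class 0 = background, class 1 = foreground.
logits = torch.randn(4, 2, 64, 64)            # N, C, H, W raw network outputs
labels = torch.randint(0, 2, (4, 64, 64))     # N, H, W integer class labels

weights = torch.tensor([0.2, 0.8])            # down-weight background, up-weight foreground (illustrative values)
criterion = nn.CrossEntropyLoss(weight=weights)
loss = criterion(logits, labels)
print(loss.item())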

2. Focal loss

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
Kaiming He's team introduced Focal Loss in the RetinaNet paper to address the imbalance between easy and hard samples; let's review it.
The α factor balances the classes, and the modulating factor (1 - p_t)^γ down-weights samples that are already classified with high confidence, so the loss concentrates on hard, misclassified examples.

import numpy as np  # required by the isinstance check on alpha in forward()

class FocalLoss(nn.Module):
   """
   copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
   This is an implementation of Focal Loss with smooth label cross entropy supported, which is proposed in
   'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)
       Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
   :param num_class:
   :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
   :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                   focus on hard misclassified example
   :param smooth: (float,double) smooth value when cross entropy
   :param balance_index: (int) balance class index, should be specific when alpha is float
   :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
   """

   def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
       super(FocalLoss, self).__init__()
       self.apply_nonlin = apply_nonlin
       self.alpha = alpha
       self.gamma = gamma
       self.balance_index = balance_index
       self.smooth = smooth
       self.size_average = size_average

       if self.smooth is not None:
           if self.smooth < 0 or self.smooth > 1.0:
               raise ValueError('smooth value should be in [0,1]')

   def forward(self, logit, target):
       if self.apply_nonlin is not None:
           logit = self.apply_nonlin(logit)
       num_class = logit.shape[1]

       if logit.dim() > 2:
           # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
           logit = logit.view(logit.size(0), logit.size(1), -1)
           logit = logit.permute(0, 2, 1).contiguous()
           logit = logit.view(-1, logit.size(-1))
       target = torch.squeeze(target, 1)
       target = target.view(-1, 1)
       # print(logit.shape, target.shape)
       # 
       alpha = self.alpha

       if alpha is None:
           alpha = torch.ones(num_class, 1)
       elif isinstance(alpha, (list, np.ndarray)):
           assert len(alpha) == num_class
           alpha = torch.FloatTensor(alpha).view(num_class, 1)
           alpha = alpha / alpha.sum()
       elif isinstance(alpha, float):
           alpha = torch.ones(num_class, 1)
           alpha = alpha * (1 - self.alpha)
           alpha[self.balance_index] = self.alpha

       else:
           raise TypeError('Not support alpha type')
       
       if alpha.device != logit.device:
           alpha = alpha.to(logit.device)

       idx = target.cpu().long()

       one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
       one_hot_key = one_hot_key.scatter_(1, idx, 1)
       if one_hot_key.device != logit.device:
           one_hot_key = one_hot_key.to(logit.device)

       if self.smooth:
           one_hot_key = torch.clamp(
               one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
       pt = (one_hot_key * logit).sum(1) + self.smooth
       logpt = pt.log()

       gamma = self.gamma

       alpha = alpha[idx]
       alpha = torch.squeeze(alpha)
       loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

       if self.size_average:
           loss = loss.mean()
       else:
           loss = loss.sum()
       return loss
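
A quick usage sketch for the class above (the shapes, alpha, and gamma values are my own assumptions, not from the original post). apply_nonlin turns the raw logits into probabilities, so a channel-wise Softmax fits a multi-class segmentation head:

focal = FocalLoss(apply_nonlin=nn.Softmax(dim=1), alpha=[0.25, 0.75], gamma=2)

logits = torch.randn(2, 2, 32, 32)           # N, C, H, W raw network outputs
labels = torch.randint(0, 2, (2, 32, 32))    # N, H, W integer class labels
loss = focal(logits, labels)
print(loss.item())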

3. Dice loss

Dice = 2|X∩Y| / (|X| + |Y|),  Dice Loss = 1 - Dice
The Dice coefficient is a set-similarity measure, usually used to compute the similarity between two samples, and belongs to metric learning. X is the ground-truth mask and Y is the predicted mask. We want the intersection of X and Y, and its proportion of the total area, to be as large as possible; since the loss has to decrease during training, the ratio is subtracted from 1 (equivalent to putting a negative sign in front of it).
It can alleviate the negative impact of the foreground/background (area) imbalance in the samples, i.e. the situation where most of the image contains no target and only a small region does. Training with Dice Loss focuses more on mining the foreground region and therefore keeps false negatives low, but it suffers from loss saturation, whereas CE Loss weights every pixel equally. Using Dice Loss alone therefore often does not give good results; it is usually combined with another term, e.g. Dice Loss + CE Loss or Dice Loss + Focal Loss.

Here is the link to the original text: https://blog.csdn.net/Mike_honor/article/details/125871091

def dice_loss(prediction, target):
    """Calculating the dice loss
    Args:
        prediction = predicted image
        target = Targeted image
    Output:
        dice_loss"""

    smooth = 1.0

    i_flat = prediction.view(-1)
    t_flat = target.view(-1)

    intersection = (i_flat * t_flat).sum()

    return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))

def calc_loss(prediction, target, bce_weight=0.5):
    """Calculating the loss and metrics
    Args:
        prediction = predicted image
        target = Targeted image
        metrics = Metrics printed
        bce_weight = 0.5 (default)
    Output:
        loss : dice loss of the epoch """
    bce = F.binary_cross_entropy_with_logits(prediction, target)
    prediction = torch.sigmoid(prediction)  # F.sigmoid is deprecated; torch.sigmoid is equivalent
    dice = dice_loss(prediction, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss
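
An illustrative call for a binary mask (the tensor shapes and the 0.5 weighting below are assumptions):

prediction = torch.randn(2, 1, 64, 64)                   # raw logits, N x 1 x H x W
target = torch.randint(0, 2, (2, 1, 64, 64)).float()     # binary ground-truth mask
loss = calc_loss(prediction, target, bce_weight=0.5)     # 0.5 * BCE + 0.5 * Dice
print(loss.item())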

4. IoU loss

IoU = |X∩Y| / |X∪Y|,  IoU Loss = 1 - IoU
This loss is similar to Dice loss and also measures set similarity in a metric-learning sense. It is worth trying in your experiments; it works remarkably well for the convergence of small-target segmentation.

def SoftIoULoss(pred, target):
    # pred holds raw logits; convert to probabilities first
    pred = torch.sigmoid(pred)
    smooth = 1

    # soft intersection and union, accumulated over the whole batch
    intersection = pred * target
    iou = (intersection.sum() + smooth) / \
          (pred.sum() + target.sum() - intersection.sum() + smooth)

    # a per-sample variant would sum over axis=(1, 2, 3) instead
    # and average the resulting per-sample IoU values
    return 1 - iou.mean()
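
Usage mirrors the Dice example above; the shapes are again illustrative assumptions:

pred = torch.randn(2, 1, 64, 64)                       # raw logits
target = torch.randint(0, 2, (2, 1, 64, 64)).float()   # binary mask
loss = SoftIoULoss(pred, target)
print(loss.item())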

5. Tversky loss

Segmentation tasks also have different emphases. Medical segmentation, for example, cares more about recall (high sensitivity): the real mask should be covered by the prediction as much as possible, and extra false-positive predictions matter less. Here B is the ground-truth mask and A is the predicted mask; |A\B| counts the false positives and |B\A| counts the false negatives, and alpha and beta control the trade-off between them. If we care more about recall, the contribution of |B\A| is amplified.
T(A, B) = |A∩B| / (|A∩B| + α|A\B| + β|B\A|),  Tversky Loss = 1 - T(A, B)
Alpha and beta therefore trade off recall against precision. If we want a higher recall for the target, we can choose a larger beta (the implementation below uses alpha = 0.3 and beta = 0.7).

class TverskyLoss(nn.Module):
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                square=False):
       """
       paper: https://arxiv.org/pdf/1706.05721.pdf
       """
       super(TverskyLoss, self).__init__()

       self.square = square
       self.do_bg = do_bg
       self.batch_dice = batch_dice
       self.apply_nonlin = apply_nonlin
       self.smooth = smooth
       self.alpha = 0.3
       self.beta = 0.7

   def forward(self, x, y, loss_mask=None):
       shp_x = x.shape

       if self.batch_dice:
           axes = [0] + list(range(2, len(shp_x)))
       else:
           axes = list(range(2, len(shp_x)))

       if self.apply_nonlin is not None:
           x = self.apply_nonlin(x)

       tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)


       tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)

       if not self.do_bg:
           if self.batch_dice:
               tversky = tversky[1:]
           else:
               tversky = tversky[:, 1:]
       tversky = tversky.mean()

       return -tversky
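
Note that the class above calls a helper get_tp_fp_fn that the snippet does not define (it appears to come from the nnU-Net loss utilities). Below is a simplified sketch of such a helper for the common case of integer label maps, plus an illustrative call; it ignores the loss_mask and square options and is my own assumption, not the original implementation:

def get_tp_fp_fn(net_output, gt, axes, mask=None, square=False):
    """Simplified stand-in: net_output is (N, C, ...) probabilities, gt is an
    integer label map of shape (N, ...) or (N, 1, ...). The mask and square
    arguments are accepted for signature compatibility but ignored here."""
    num_classes = net_output.shape[1]

    if gt.dim() == net_output.dim():   # (N, 1, ...) -> (N, ...)
        gt = gt[:, 0]

    # one-hot encode the ground truth and move the class axis next to the batch axis
    y_onehot = F.one_hot(gt.long(), num_classes)                           # (N, ..., C)
    y_onehot = y_onehot.permute(0, gt.dim(), *range(1, gt.dim())).float()  # (N, C, ...)

    tp = (net_output * y_onehot).sum(dim=axes)
    fp = (net_output * (1 - y_onehot)).sum(dim=axes)
    fn = ((1 - net_output) * y_onehot).sum(dim=axes)
    return tp, fp, fn

# Illustrative call; shapes and the softmax non-linearity are assumptions.
tversky_loss = TverskyLoss(apply_nonlin=nn.Softmax(dim=1), do_bg=False)
logits = torch.randn(2, 2, 64, 64)              # N, C, H, W raw outputs
labels = torch.randint(0, 2, (2, 64, 64))       # N, H, W integer class labels
loss = tversky_loss(logits, labels)             # returns -Tversky index, so lower is better
print(loss.item())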

Summary

After a series of experiments, I found that the latter four loss functions are more suitable for training small-target segmentation networks. But every task is different; if you have plenty of time, try them one by one.

Source: blog.csdn.net/jijiarenxiaoyudi/article/details/128360405