분할 네트워크 손실 기능 요약! 교차 엔트로피, 초점 손실, 주사위, iou, TverskyLoss!


머리말

세분화 네트워크 작업의 실제 훈련 과정에서 손실 함수의 선택은 특히 중요합니다. Semantic Segmentation의 경우 Positive와 Negative Sample 사이의 불균형이나 Category의 불균형이 있을 가능성이 높으므로 적절한 Loss Function을 선택하는 것이 모델 수렴과 정확한 예측에 중요한 역할을 합니다.


1. 교차 엔트로피 손실

여기에 이미지 설명 삽입
M은 범주의 개수,
yic는 요소가 속한 범주를 나타내는 지시함수,
pic는 예측확률, 관찰된 표본이 범주 c에 속할 예측확률이며, 예측확률은 전진;

단점:
Cross-entropy Loss는 대부분의 시맨틱 분할 시나리오에서 사용할 수 있지만 전경 및 배경만 분할에 사용되는 경우 전경 픽셀 수가 배경 픽셀 수보다 훨씬 작은 경우 명백한 단점이 있습니다. , 즉 배경 요소의 수가 전경 요소의 수보다 훨씬 많고 배경 요소의 손실 함수 구성 요소가 우세하여 모델이 배경에 크게 편향되어 모델 교육 및 예측 결과가 좋지 않습니다.

마찬가지로 BCEloss도 이 문제에 직면하고 있는데, BCEloss는 다음과 같습니다.
여기에 이미지 설명 삽입
모든 N 범주에 대해 이진 분류 손실 계산을 수행합니다.

  #二值交叉熵,这里输入要经过sigmoid处理
import torch
import torch.nn as nn
import torch.nn.functional as F
nn.BCELoss(F.sigmoid(input), target)
#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层
nn.CrossEntropyLoss(input, target)

二, 초점 손실

여기에 이미지 설명 삽입
He Kaiming 팀은 어려운 샘플과 쉬운 샘플의 수의 불균형을 해결하기 위해 RetinaNet 논문에서 Focal Loss를 도입했습니다.
표본의 수와 신뢰도에 벌점을 부여하고 큰 표본의 손실 가중치와 신뢰도가 높은 표본의 손실 가중치는 더 낮은 것으로 간주합니다.

class FocalLoss(nn.Module):
   """
   copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
   This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
   'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
       Focal_Loss= -1*alpha*(1-pt)*log(pt)
   :param num_class:
   :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
   :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                   focus on hard misclassified example
   :param smooth: (float,double) smooth value when cross entropy
   :param balance_index: (int) balance class index, should be specific when alpha is float
   :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
   """

   def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
       super(FocalLoss, self).__init__()
       self.apply_nonlin = apply_nonlin
       self.alpha = alpha
       self.gamma = gamma
       self.balance_index = balance_index
       self.smooth = smooth
       self.size_average = size_average

       if self.smooth is not None:
           if self.smooth < 0 or self.smooth > 1.0:
               raise ValueError('smooth value should be in [0,1]')

   def forward(self, logit, target):
       if self.apply_nonlin is not None:
           logit = self.apply_nonlin(logit)
       num_class = logit.shape[1]

       if logit.dim() > 2:
           # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
           logit = logit.view(logit.size(0), logit.size(1), -1)
           logit = logit.permute(0, 2, 1).contiguous()
           logit = logit.view(-1, logit.size(-1))
       target = torch.squeeze(target, 1)
       target = target.view(-1, 1)
       # print(logit.shape, target.shape)
       # 
       alpha = self.alpha

       if alpha is None:
           alpha = torch.ones(num_class, 1)
       elif isinstance(alpha, (list, np.ndarray)):
           assert len(alpha) == num_class
           alpha = torch.FloatTensor(alpha).view(num_class, 1)
           alpha = alpha / alpha.sum()
       elif isinstance(alpha, float):
           alpha = torch.ones(num_class, 1)
           alpha = alpha * (1 - self.alpha)
           alpha[self.balance_index] = self.alpha

       else:
           raise TypeError('Not support alpha type')
       
       if alpha.device != logit.device:
           alpha = alpha.to(logit.device)

       idx = target.cpu().long()

       one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
       one_hot_key = one_hot_key.scatter_(1, idx, 1)
       if one_hot_key.device != logit.device:
           one_hot_key = one_hot_key.to(logit.device)

       if self.smooth:
           one_hot_key = torch.clamp(
               one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
       pt = (one_hot_key * logit).sum(1) + self.smooth
       logpt = pt.log()

       gamma = self.gamma

       alpha = alpha[idx]
       alpha = torch.squeeze(alpha)
       loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

       if self.size_average:
           loss = loss.mean()
       else:
           loss = loss.sum()
       return loss

1. 주사위 손실 기능

여기에 이미지 설명 삽입
유사도 측정 기능을 설정합니다. 일반적으로 두 샘플 간의 유사성을 계산하는 데 사용되며 메트릭 학습에 속합니다. X는 실제 타겟 마스크, Y는 예측 타겟 마스크 항상 X와 Y의 교집합이 최대한 크고 비율이 최대한 크길 바라지만 손실은 점진적으로 줄여야 하므로 추가 비율 앞에 음수 부호.
샘플의 전경과 배경(영역)의 불균형으로 인한 부정적인 영향을 완화할 수 있습니다 전경과 배경의 불균형은 이미지의 대부분의 영역에 대상이 포함되지 않고 영역의 일부만 포함됨을 의미합니다 대상을 포함합니다. Dice Loss 훈련은 낮은 FN을 보장하는 전경 영역의 마이닝에 더 많은 관심을 기울이지만 손실 포화 문제가 있는 반면 CE Loss는 각 픽셀의 손실을 동일하게 계산합니다. 따라서 Dice Loss만으로는 좋은 결과를 얻을 수 없는 경우가 많으며, Dice Loss+CE Loss 또는 Dice Loss+Focal Loss와 같이 조합하여 사용해야 합니다.

원본 텍스트에 대한 링크는 다음과 같습니다. https://blog.csdn.net/Mike_honor/article/details/125871091

def dice_loss(prediction, target):
    """Calculating the dice loss
    Args:
        prediction = predicted image
        target = Targeted image
    Output:
        dice_loss"""

    smooth = 1.0

    i_flat = prediction.view(-1)
    t_flat = target.view(-1)

    intersection = (i_flat * t_flat).sum()

    return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))

def calc_loss(prediction, target, bce_weight=0.5):
    """Calculating the loss and metrics
    Args:
        prediction = predicted image
        target = Targeted image
        metrics = Metrics printed
        bce_weight = 0.5 (default)
    Output:
        loss : dice loss of the epoch """
    bce = F.binary_cross_entropy_with_logits(prediction, target)
    prediction = F.sigmoid(prediction)
    dice = dice_loss(prediction, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss

1. IOU 손실

여기에 이미지 설명 삽입
손실함수는 미터법 학습으로 측정되는 다이스 손실함수와 유사하며, 실험에서 사용해 볼 수 있고, 작은 목표 세분화의 수렴에 기적 같은 효과가 있습니다!

def SoftIoULoss( pred, target):
        # Old One
        pred = torch.sigmoid(pred)
        smooth = 1

        # print("pred.shape: ", pred.shape)
        # print("target.shape: ", target.shape)

        intersection = pred * target
        loss = (intersection.sum() + smooth) / (pred.sum() + target.sum() -intersection.sum() + smooth)

        # loss = (intersection.sum(axis=(1, 2, 3)) + smooth) / \
        #        (pred.sum(axis=(1, 2, 3)) + target.sum(axis=(1, 2, 3))
        #         - intersection.sum(axis=(1, 2, 3)) + smooth)

        loss = 1 - loss.mean()
        # loss = (1 - loss).mean()

        return loss

一、TverskyLoss

분할 작업도 강조점이 다릅니다.예를 들어 의료 분할은 재현율(높은 민감도)에 더 많은 관심을 기울입니다. 즉, 실제 마스크는 가능한 한 많이 예측되며 예측 마스크가 더 많은 예측. B는 실제 마스크이고 A는 예측된 마스크입니다. |AB|는 위양성, |BA|는 위음성, 알파와 베타는 위양성과 위음성 간의 균형을 제어할 수 있습니다. 회상에 더 집중하면 |BA|의 효과가 증폭됩니다.
여기에 이미지 설명 삽입
그 중 알파와 베타는 검색률과 정확도에 영향을 줄 수 있는데, 타겟의 재현율을 높이려면 더 높은 베타를 선택할 수 있습니다.
여기에 이미지 설명 삽입

class TverskyLoss(nn.Module):
   def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                square=False):
       """
       paper: https://arxiv.org/pdf/1706.05721.pdf
       """
       super(TverskyLoss, self).__init__()

       self.square = square
       self.do_bg = do_bg
       self.batch_dice = batch_dice
       self.apply_nonlin = apply_nonlin
       self.smooth = smooth
       self.alpha = 0.3
       self.beta = 0.7

   def forward(self, x, y, loss_mask=None):
       shp_x = x.shape

       if self.batch_dice:
           axes = [0] + list(range(2, len(shp_x)))
       else:
           axes = list(range(2, len(shp_x)))

       if self.apply_nonlin is not None:
           x = self.apply_nonlin(x)

       tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)


       tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)

       if not self.do_bg:
           if self.batch_dice:
               tversky = tversky[1:]
           else:
               tversky = tversky[:, 1:]
       tversky = tversky.mean()

       return -tversky

요약하다

일련의 실험 후, 후자의 네 가지 손실 함수가 작은 대상 세분화 네트워크 훈련에 더 적합하다는 것이 밝혀졌습니다. 하지만 각 작업은 다르므로 시간이 충분하면 하나씩 시도해 볼 수 있습니다.

Supongo que te gusta

Origin blog.csdn.net/jijiarenxiaoyudi/article/details/128360405
Recomendado
Clasificación