Understanding and summary of commonly used classification losses: CE Loss, Focal Loss, and GHMC Loss

1. CE Loss

Definition

Cross-Entropy Loss (CE Loss) measures the difference between two probability distributions over the same random variable: the closer the two distributions are, the smaller the cross-entropy loss and the more accurate the model's prediction.

Formula

Binary classification

The CE Loss formula for binary classification is as follows,

$$L_{CE} = -\frac{1}{M+N}\sum_{i=1}^{M+N}\left[\,y_{i}\log p_{i} + (1-y_{i})\log(1-p_{i})\,\right]$$

where M: number of positive samples, N: number of negative samples, $y_{i}$: true label, $p_{i}$: predicted probability.
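As a quick sanity check on the formula (illustrative numbers), a confident correct prediction contributes a small loss while a confident wrong prediction contributes a large one:

$$y_{i}=1,\ p_{i}=0.9 \Rightarrow -\log(0.9)\approx 0.105 \qquad\qquad y_{i}=1,\ p_{i}=0.1 \Rightarrow -\log(0.1)\approx 2.303$$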

Multi-class classification

When calculating multi-class CE Loss, the model output must first be passed through softmax. The formulas are as follows,

$$p_{j} = \frac{e^{output_{j}}}{\sum_{k} e^{output_{k}}}, \qquad L_{CE} = -\sum_{j}\hat{y}_{j}\log p_{j}$$

where output: model output, p: the softmax of the model output, $\hat{y}$: the one-hot encoding of the true label (if the model performs 5-way classification and $y_{i}=2$, then $\hat{y} = [0, 0, 1, 0, 0]$).
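For a quick worked example (values chosen arbitrarily): if the softmax output of a 5-way classifier is p = [0.1, 0.2, 0.5, 0.1, 0.1] and the true class is $y_{i}=2$, only the true-class term survives the multiplication with the one-hot vector:

$$L_{CE} = -\sum_{j}\hat{y}_{j}\log p_{j} = -\log(0.5) \approx 0.693$$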

Code

Binary classification

import torch
import torch.nn as nn
import math

criterion = nn.BCELoss()
output = torch.rand(1, requires_grad=True)     # predicted probability in [0, 1)
label = torch.randint(0, 2, (1,)).float()      # random binary label (0 or 1)
loss = criterion(output, label)

print("Predicted value:", output)
print("True label:", label)
print("nn.BCELoss:", loss)

# manual BCE computation for comparison
for i in range(label.shape[0]):
    if label[i] == 0:
        res = -math.log(1 - output[i])
    elif label[i] == 1:
        res = -math.log(output[i])
print("Manual result:", res)


"""
Predicted value: tensor([0.7359], requires_grad=True)
True label: tensor([0.])
nn.BCELoss: tensor(1.3315, grad_fn=<BinaryCrossEntropyBackward0>)
Manual result: 1.331509556677378
"""

Multi-class classification

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("预测值:", output)
print("真实值:", label)
print("nn.CrossEntropyLoss:", loss)

output = torch.softmax(output, dim=1)
print("softmax后的预测值:", output)

one_hot = torch.zeros_like(output).scatter_(1, label.view(-1, 1), 1)
print("真实值对应的one_hot编码", one_hot)

res = (-torch.log(output) * one_hot).sum()
print("自己的计算结果", res)


"""
Predicted value: tensor([[-0.7459, -0.3963, -1.8046,  0.6815,  0.2965]], requires_grad=True)
True label: tensor([1])
nn.CrossEntropyLoss: tensor(1.9296, grad_fn=<NllLossBackward0>)
Predicted value after softmax: tensor([[0.1024, 0.1452, 0.0355, 0.4266, 0.2903]], grad_fn=<SoftmaxBackward0>)
One-hot encoding of the true label: tensor([[0., 1., 0., 0., 0.]])
Manual result: tensor(1.9296, grad_fn=<SumBackward0>)
"""

2. Focal Loss

Definition

Although CE Loss can measure the difference between two probability distributions over the same random variable, it cannot solve the following two problems: 1. the imbalance between the numbers of positive and negative samples (for example, in the classification branch of CenterNet, only the target center point is taken as a positive sample while every other pixel on the feature map is a negative sample, so negatives vastly outnumber positives); 2. the imbalance between easy and hard samples (the loss of the huge number of easy samples accounts for the vast majority of the overall loss and dominates the gradient).

To solve these problems, Focal Loss builds on CE Loss and introduces: 1. a positive/negative sample balancing factor to address the class-count imbalance; 2. a hard/easy sample modulating factor so that the loss focuses on hard-to-classify samples.

Formula

Binary classification

The formula is as follows,

$$FL(p_{t}) = -\alpha_{t}\,(1-p_{t})^{\gamma}\log(p_{t})$$

$$p_{t} = \begin{cases} p, & y=1 \\ 1-p, & y=0 \end{cases} \qquad \alpha_{t} = \begin{cases} \alpha, & y=1 \\ 1-\alpha, & y=0 \end{cases}$$

where $\alpha$: the positive/negative sample balancing factor, $\gamma$: the hard/easy sample modulating factor.
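To see how the modulating factor shifts attention to hard samples, a quick calculation with $\gamma=2$ (illustrative numbers): an easy sample with $p_{t}=0.9$ keeps only 1% of its CE loss, while a hard sample with $p_{t}=0.1$ keeps 81% of it,

$$(1-0.9)^{2} = 0.01 \qquad\qquad (1-0.1)^{2} = 0.81$$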

Multi-class classification

$$FL = -\alpha_{y_{i}}\,(1-p_{y_{i}})^{\gamma}\log(p_{y_{i}})$$

where $\alpha_{y_{i}}$: the weight of class $y_{i}$, and $p_{y_{i}}$: the softmax probability of the true class.

Code

Binary classification

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss

Step 1, apply sigmoid to the inputs,

p = torch.sigmoid(inputs)

Step 2, then calculate the CE Loss,

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

Step 3, define $p_{t}$, the probability the model assigns to the true class:

p_t = p * targets + (1 - p) * (1 - targets)

Step 4, apply the hard/easy sample modulating factor to the CE Loss,

loss = ce_loss * ((1 - p_t) ** gamma)

Step 5, define $\alpha_{t}$:

alpha_t = alpha * targets + (1 - alpha) * (1 - targets)

Step 6, apply the positive/negative sample balancing factor to the loss from step 4,

loss = alpha_t * loss
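A minimal usage sketch of the function above (random tensors; the values of alpha and gamma are just illustrative):

import torch

inputs = torch.randn(4, requires_grad=True)   # raw logits for 4 samples
targets = torch.tensor([1., 0., 1., 0.])      # binary labels

loss = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2, reduction="mean")
print(loss)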

Multi-class classification

import torch
import torch.nn.functional as F


def multi_cls_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: torch.Tensor = None,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:

    inputs = inputs.float()
    targets = targets.long()
    # per-sample CE loss (note: reduction is an argument of F.cross_entropy;
    # it cannot be passed when calling an nn.CrossEntropyLoss instance)
    ce_loss = F.cross_entropy(inputs, targets, reduction="none")
    # probability the model assigns to the true class of each sample
    p = torch.softmax(inputs, dim=1)
    p_t = p.gather(1, targets.view(-1, 1)).squeeze(1)
    # down-weight easy samples with the modulating factor
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha is not None:
        # per-class weight selected by each sample's true class
        alpha_t = alpha[targets]
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
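A minimal usage sketch of the multi-class version (random tensors; the per-class alpha weights are illustrative):

import torch

inputs = torch.randn(4, 5, requires_grad=True)    # logits for 4 samples, 5 classes
targets = torch.tensor([1, 0, 3, 2])              # true class indices
alpha = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])   # per-class weights (illustrative)

loss = multi_cls_focal_loss(inputs, targets, alpha=alpha, gamma=2, reduction="mean")
print(loss)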

3. GHMC Loss

Definition

Focal Loss, built on CE Loss, solves the problems of positive/negative imbalance and of easy samples overwhelming hard ones, but it also makes the model pay too much attention to extremely hard samples (outliers), which can skew training. To address this, GHMC (Gradient Harmonizing Mechanism-C) defines a gradient norm that is proportional to classification difficulty; the goal is to have the model focus neither on the samples that are already easy to learn nor on the samples that are exceptionally hard to classify.

Formula

1. Define the gradient norm

The binary CE Loss is,

$$L_{CE} = -\left[\,p^{\ast}\log p + (1-p^{\ast})\log(1-p)\,\right]$$

Let x be the model output and p = sigmoid(x); taking the partial derivative of the loss with respect to x gives,

$$\frac{\partial L_{CE}}{\partial x} = p - p^{\ast}$$

Therefore, the gradient norm is defined as,

$$g = \left|p - p^{\ast}\right| = \begin{cases} 1-p, & p^{\ast}=1 \\ p, & p^{\ast}=0 \end{cases}$$

where p: predicted probability, $p^{\ast}$: ground-truth label.

(Figure: the relationship between the gradient norm and the number of samples.)

2. Define the gradient density (the number of samples per unit gradient-norm length around g)

$$GD(g) = \frac{1}{l_{\varepsilon}(g)}\sum_{k=1}^{N}\delta_{\varepsilon}(g_{k}, g)$$

where $g_{k}$: the gradient norm of the k-th sample, $\delta_{\varepsilon}(g_{k}, g)$: an indicator of whether $g_{k}$ falls in the interval $(g-\frac{\varepsilon}{2}, g+\frac{\varepsilon}{2})$, and $l_{\varepsilon}(g)$: the length of that interval.

3. Define the gradient density harmonizing parameter

$$\beta_{i} = \frac{N}{GD(g_{i})}$$

where N: the total number of samples.
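A quick worked example of the harmonizing parameter (illustrative numbers): with N = 1000 samples, a sample whose gradient norm lies in a dense region with $GD(g_{i}) = 800$ is down-weighted, while a sample in a sparse region with $GD(g_{i}) = 10$ is up-weighted,

$$\beta_{i} = \frac{1000}{800} = 1.25 \qquad\qquad \beta_{i} = \frac{1000}{10} = 100$$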

4. Define GHMC Loss

$$L_{GHMC} = \frac{1}{N}\sum_{i=1}^{N}\beta_{i}\,L_{CE}(p_{i}, p_{i}^{\ast}) = \sum_{i=1}^{N}\frac{L_{CE}(p_{i}, p_{i}^{\ast})}{GD(g_{i})}$$

Code

import torch
import torch.nn as nn
import torch.nn.functional as F


def _expand_binary_labels(labels, label_weights, label_channels):
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1
    bin_label_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return bin_labels, bin_label_weights


class GHMC(nn.Module):
    def __init__(
            self,
            bins=10,
            momentum=0,
            use_sigmoid=True,
            loss_weight=1.0):
        super(GHMC, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.edges = [float(x) / bins for x in range(bins+1)]
        self.edges[-1] += 1e-6
        if momentum > 0:
            self.acc_sum = [0.0 for _ in range(bins)]
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """ Args:
        pred [batch_num, class_num]:
            The direct prediction of classification fc layer.
        target [batch_num, class_num]:
            Binary class target for each sample.
        label_weight [batch_num, class_num]:
            the value is 1 if the sample is valid and 0 if ignored.
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # compute the gradient norm g = |sigmoid(pred) - target|
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        
        # count non-empty bins; used later to normalize the weights
        n = 0
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i+1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight

Step 1, divide the gradient norm range [0, 1] into bins (10 by default) equal intervals,

self.edges = [float(x) / bins for x in range(bins+1)]
"""
[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000, 1.0000]
"""

Step 2, calculate the gradient norm,

g = torch.abs(pred.sigmoid().detach() - target)

Step 3, count how many gradient norms fall into each bin and convert the counts into per-sample weights,

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0
for i in range(self.bins):
    inds = (g >= edges[i]) & (g < edges[i+1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        if mmt > 0:
            self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_bin
            weights[inds] = tot / self.acc_sum[i]
        else:
            weights[inds] = tot / num_in_bin
        n += 1
if n > 0:
    weights = weights / n

Step 4, calculate the GHMC Loss,

loss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / tot * self.loss_weight
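A minimal usage sketch of the GHMC class above (random tensors, shapes chosen only for illustration):

import torch

ghmc = GHMC(bins=10, momentum=0.75)

pred = torch.randn(8, 5)                        # raw logits [batch_num, class_num]
target = torch.randint(0, 2, (8, 5)).float()    # binary target for each class
label_weight = torch.ones(8, 5)                 # 1 = valid, 0 = ignored

loss = ghmc(pred, target, label_weight)
print(loss)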

