The Focal Loss loss function

Table of contents

Preface

Cross entropy loss function

Balanced cross-entropy

Focal Loss

Code


Preface

Focal Loss is a loss function commonly used to address class imbalance. It was proposed by Tsung-Yi Lin, Kaiming He and colleagues in the paper "Focal Loss for Dense Object Detection", where it is used in vision to handle the extreme imbalance between positive and negative samples, and between easy and hard samples, in one-stage object detection. This article starts from the cross-entropy loss function, analyzes the sample-imbalance problem, compares Focal Loss with cross-entropy, and explains why Focal Loss is effective.

Cross entropy loss function

Cross-entropy is the usual loss for classification, and Focal Loss is an improvement built on top of it, so introducing cross-entropy first makes Focal Loss easier to understand.

The binary cross-entropy loss function is defined as follows:

CE(p, y) = -\log(p) if y = 1, and -\log(1 - p) otherwise

where y ∈ {0, 1} is the ground-truth label and p ∈ [0, 1] is the probability the model predicts for the positive class. Now define p_{t} as follows:

p_{t} = p if y = 1, and 1 - p otherwise

Substituting p_{t} gives the compact form of the loss function:

CE(p, y) = CE(p_{t}) = -\log(p_{t})
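As a quick illustration (my own sketch, not from the original article), the following PyTorch snippet computes the per-sample binary cross-entropy and the corresponding p_{t} from raw logits; the variable names (logits, targets) are assumptions for the example.

import torch
import torch.nn.functional as F

# Raw model outputs (logits) and 0/1 ground-truth labels
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

# Per-sample binary cross-entropy: -log(p_t)
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

# p_t = p for positive samples, 1 - p for negative samples
p = torch.sigmoid(logits)
pt = torch.where(targets == 1, p, 1 - p)

print(ce)              # same values as ...
print(-torch.log(pt))  # ... -log(p_t)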

Balanced cross-entropy

Because positive and negative samples are extremely imbalanced, using the cross-entropy loss directly does not give good results, so the first step is to balance the cross-entropy.

Generally, to address class imbalance, a weighting factor \alpha ∈ [0, 1] is added in front of each class in the loss function to counteract the imbalance. Defining \alpha_{t} in a way analogous to p_{t}, we obtain the balanced binary cross-entropy loss:

CE(p_{t}) = -\alpha_{t} \log(p_{t})

Balanced cross-entropy uses \alpha to balance the importance of positive and negative samples, but it does not distinguish between easy and hard samples. With a large class imbalance, the loss from easily classified samples makes up the vast majority of the total loss and dominates the gradient during training, overwhelming the cross-entropy loss.
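Continuing the sketch above (again my own illustration, reusing the ce and targets variables; which class gets \alpha depends on which class is rare), the \alpha_{t} weighting can be applied per sample like this:

alpha = 0.25
# alpha_t = alpha for positive samples, 1 - alpha for negative samples
alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
balanced_ce = alpha_t * ce   # -alpha_t * log(p_t)
print(balanced_ce.mean())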

Focal Loss

Focal Loss builds on the balanced cross-entropy loss by adding a modulating factor that down-weights easy-to-classify samples and focuses training on hard samples. It is defined as follows:

FL(p_{t}) = -\alpha_{t} (1 - p_{t})^{\gamma} \log(p_{t})

The \alpha_{t} weight helps deal with class imbalance. Here (1 - p_{t})^{\gamma} is the modulating (adjustment) factor and \gamma ≥ 0 is a tunable focusing parameter; Figure 1 of the paper plots the focal loss curve for different values of \gamma ∈ [0, 5].

\gamma controls the shape of the curve. The larger \gamma is, the smaller the loss for well-classified samples, so the model's attention can be focused on the samples that are hard to classify. A larger \gamma also widens the range of p_{t} values that receive a small loss. When \gamma = 0, the expression reduces to the ordinary cross-entropy loss.

In that figure, the "blue" curve represents the cross-entropy loss. The x-axis is the probability the model assigns to the true label (for simplicity, call it pt), and the y-axis is the value of the focal loss or the CE loss for that pt.

As the figure shows, even when the model predicts the true label with a probability of around 0.6, the cross-entropy loss is still about 0.5. To reduce the loss during training, the model therefore has to predict the true label with an even higher probability; in other words, cross-entropy pushes the model to be very confident in its predictions, and this can hurt model performance.

Deep learning models that become overconfident tend to generalize worse.

With Focal Loss and \gamma > 1, the training loss for "well-classified samples" — those the model already predicts correctly with high probability — is sharply reduced, while for "hard-to-classify samples", for example those with a predicted probability below 0.5, the loss is reduced much less.
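To make this concrete, here is a small sketch (my own numbers, not from the article) that compares plain cross-entropy -\log(p_{t}) with the unweighted focal term (1 - p_{t})^{\gamma} · (-\log(p_{t})) at a few values of p_{t} and \gamma:

import math

for pt in (0.3, 0.6, 0.9):
    ce = -math.log(pt)
    row = [f"pt={pt:.1f}  CE={ce:.3f}"]
    for gamma in (0.5, 1, 2, 5):
        fl = (1 - pt) ** gamma * ce   # focal loss without the alpha_t weight
        row.append(f"FL(g={gamma})={fl:.3f}")
    print("  ".join(row))

With \gamma = 2, the loss at pt = 0.9 is cut by a factor of 100, while at pt = 0.3 it is only roughly halved, which matches the behaviour described above.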

Focal Loss Features:

  • When p_{t} is very small (a hard sample, whether positive or negative), the modulating factor (1 - p_{t})^{\gamma} is close to 1 and the sample's weight in the loss is essentially unchanged; when p_{t} is large (an easy sample), the modulating factor approaches 0 and the sample's weight in the loss is greatly reduced.
  • The focusing parameter \gamma adjusts how strongly easy samples are down-weighted: the larger \gamma is, the stronger the down-weighting.

By analyzing the characteristics of the Focal Loss function, we can see that this loss function reduces the weight of easy-to-classify samples and focuses on difficult-to-classify samples.

Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedFocalLoss(nn.Module):
    "Weighted (alpha-balanced) version of Focal Loss for binary classification"
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        # Per-class weights looked up by label: label 0 -> alpha, label 1 -> 1 - alpha
        self.alpha = torch.tensor([alpha, 1 - alpha]).cuda()
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-sample cross-entropy on raw logits: -log(p_t)
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        # Select alpha_t for each sample according to its label
        at = self.alpha.gather(0, targets.data.view(-1))
        # Recover p_t from the loss: BCE_loss = -log(p_t), so p_t = exp(-BCE_loss)
        pt = torch.exp(-BCE_loss)
        # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()
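A quick usage sketch (my own example, not from the referenced article): it assumes a CUDA device is available, since the class above places its alpha tensor on the GPU, and it feeds random logits and 0/1 labels through the loss.

# Example usage with random logits and binary labels
criterion = WeightedFocalLoss(alpha=0.25, gamma=2)
logits = torch.randn(8, requires_grad=True, device='cuda')  # raw model outputs
labels = torch.randint(0, 2, (8,), device='cuda').float()   # 0/1 targets
loss = criterion(logits, labels)
loss.backward()
print(loss.item())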

Main reference: Detailed explanation of the Focal Loss loss function

Also consulted: A popular (intuitive) explanation of focal loss

Paper interpretation: Detailed explanation of the Focal Loss paper

Origin: blog.csdn.net/m0_70813473/article/details/131432473