Focal Loss for classification: a simple PyTorch implementation

PS: In this task there is a huge imbalance between positive and negative samples: more than 1,500 positives versus more than 750,000 negatives, so Focal Loss is used to address it.

First of all, thanks to Code_Mart's blog for summarizing the theory (https://blog.csdn.net/Code_Mart/article/details/89736187) and for implementing and explaining both the binary and multi-class versions of Focal Loss. The discussion there with xwmwanjy666 also clarified some details.
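For reference, the alpha-balanced binary Focal Loss from the original paper (the form the alpha discussion below refers to) is:

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t)

where p_t = p when the ground truth is 1 and p_t = 1 - p when it is 0, and α_t = α for the positive class and α_t = 1 - α for the negative class.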

However, that code does not quite work with PyTorch 0.4.1, so based on the implementation mentioned in their conversation (https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/focal_loss.py) I modified it slightly:

import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, gamma=0, alpha=1):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, input, target):
        # input: raw logits of shape (N, C); target: class indices of shape (N,)
        logp = self.ce(input, target)        # cross entropy, already averaged over the batch
        p = torch.exp(-logp)                 # probability of the true class (batch average)
        loss = (1 - p) ** self.gamma * logp  # focal modulating factor
        loss = self.alpha * loss             # single global alpha weight
        return loss.mean()
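A minimal usage sketch (the batch size, class count, and sample values are just assumptions for illustration):

logits = torch.randn(4, 2)                 # raw scores for 4 samples, 2 classes
labels = torch.tensor([1, 0, 1, 1])        # integer class indices
criterion = FocalLoss(gamma=2, alpha=0.25)
print(criterion(logits, labels))           # scalar loss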

It's so simple. Is this what I want? But as I kept reading their discussion and comparing, I found it does not implement the class-dependent alpha (alpha = a when the ground truth is 1, and alpha = 1 - a when the ground truth is 0). That bothered me, so I spent more than 4 hours searching GitHub and found that most implementations either ignore this issue, or the code is very complicated (hard to understand, or the expected input did not match my needs), or it is not applied to a classification problem, and the PyTorch version used is generally below 0.4.

Until I found https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.py :

import torch
import torch.nn.functional as F


class BCEFocalLoss(torch.nn.Module):

    def __init__(self, gamma=2, alpha=None, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        # _input: raw logits; target: 0/1 labels of the same shape
        pt = torch.sigmoid(_input)
        # focal binary cross entropy: down-weight easy examples for both classes
        loss = - (1 - pt) ** self.gamma * target * torch.log(pt) - \
            pt ** self.gamma * (1 - target) * torch.log(1 - pt)
        if self.alpha:
            loss = loss * self.alpha
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
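A quick usage sketch (values chosen only for illustration); note that here the targets are float 0/1 tensors with the same shape as the logits:

logits = torch.randn(4)                        # one raw score per sample
labels = torch.tensor([1., 0., 1., 1.])        # binary targets as floats
criterion = BCEFocalLoss(gamma=2, alpha=0.25)
print(criterion(logits, labels))               # mean focal BCE over the batch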

The key line is loss = -(1-pt) ** self.gamma * target * torch.log(pt) - pt ** self.gamma * (1-target) * torch.log(1-pt), which also solves the problem above. I don't quite understand why the author did not also make alpha class-dependent (alpha = a when the ground truth is 1, alpha = 1 - a when the ground truth is 0). But borrowing this idea, the first code becomes the following:

import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, input, target):
        # note: nn.CrossEntropyLoss already averages over the batch,
        # so logp and p below are scalars rather than per-sample values
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        # alpha for positive samples, 1 - alpha for negative samples
        loss = self.alpha * (1 - p) ** self.gamma * logp * target.long() + \
               (1 - self.alpha) * p ** self.gamma * logp * (1 - target.long())
        return loss.mean()

I don't know if there is a problem with writing this way.

PS (added June 14, 2019):

I tested the modified version of the first code (the last code above) using the test code from the follow-up post "Focal Loss classification problem pytorch implementation code (continued 3)".
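The actual test code lives in that follow-up post; the following is only my own rough reconstruction of such a check (the sample input comes from the listing below, the rest is assumed, and it is not guaranteed to reproduce the exact numbers printed there):

input = torch.Tensor([[0.0543, 0.5641], [1.2221, -0.5496],
                      [-0.7951, -0.1546], [-0.4557, 1.4724]])
target = torch.tensor([1, 0, 1, 1])               # integer class indices
print(torch.softmax(input, dim=1))                # predicted probabilities
print(FocalLoss(gamma=2, alpha=0.25)(input, target))
print(nn.CrossEntropyLoss()(input, target))       # built-in loss for comparison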

The intermediate results (the loss here is the modified loss above), with the following input and target:

input=torch.Tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557,  1.4724]])
target= torch.Tensor([1,0,1,1])

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0344)
tensor(0.2966)



target= torch.Tensor([0,1,0,0])

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.5403)
tensor(0.0987)
tensor(1.5092)

From the results above, the effect is not good. The first target agrees with the predicted probabilities, so its loss should be small, and the second target is the opposite of the predictions, so its loss should be large. Although the trend is consistent, the ratios between the two cases are only about 70x, 3x, and 5x respectively (one likely reason is that nn.CrossEntropyLoss already averages over the batch, so the per-sample weighting by target in the modified code acts on a single scalar rather than on individual samples). At that point one might as well use the built-in loss function, so in the end I chose the conclusion code from "Focal Loss classification problem pytorch implementation code (continued 3)":

import torch
import torch.nn as nn
 
# binary classification
class FocalLoss(nn.Module):

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        # input: logits of size M*2, where M is the batch size
        # target: size M, with 0/1 values
        pt = torch.softmax(input, dim=1)
        p = pt[:, 1]  # predicted probability of the positive class
        # alpha for positives, 1 - alpha for negatives, with the focal modulating factor
        loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) - \
               (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
        return loss.mean()
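A usage sketch for this final version (values assumed for illustration); here target is a float 0/1 vector of length M:

logits = torch.randn(4, 2)                   # M*2 logits
labels = torch.tensor([1., 0., 1., 1.])      # M binary targets as floats
criterion = FocalLoss(gamma=2, alpha=0.25)
print(criterion(logits, labels))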

 

PS (2020-11-12 update):

The focal loss used in an earlier binary classification model:

import torch
import torch.nn as nn


class FocalLossV1(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.crit = nn.BCEWithLogitsLoss(reduction='none')
        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, label):
        '''
        args:
            logits: tensor of shape (N, ...)
            label: tensor of shape (N, ...)
        note: torch.log is applied to `logits` directly below, so despite the
        name they are expected to already be probabilities in (0, 1),
        e.g. the output of a sigmoid.
        '''

        # compute loss
        logits = logits.float()  # use fp32 even if the input is fp16
        with torch.no_grad():
            # class-dependent alpha: alpha for positives, 1 - alpha for negatives
            alpha = torch.empty_like(logits).fill_(1 - self.alpha)
            alpha[label == 1] = self.alpha
        # element-wise binary cross entropy on probabilities
        ce_loss = (-(label * torch.log(logits)) - (
                    (1 - label) * torch.log(1 - logits)))
        # ce_loss = (-(label * torch.log(torch.softmax(logits, dim=1))) - (
        #             (1 - label) * torch.log(1 - torch.softmax(logits, dim=1))))
        # pt: probability assigned to the true class
        pt = torch.where(label == 1, logits, 1 - logits)
        # ce_loss = self.crit(logits, label)  # alternative: BCEWithLogitsLoss on raw logits
        loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
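A usage sketch under the assumption noted in the comments (probabilities are passed in; the raw scores and labels here are made up for illustration):

raw = torch.randn(8)                       # hypothetical raw model outputs
probs = torch.sigmoid(raw)                 # FocalLossV1 as written expects probabilities
labels = (torch.rand(8) > 0.5).float()     # random 0/1 targets for illustration
criterion = FocalLossV1(alpha=0.25, gamma=2)
print(criterion(probs, labels))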

 


Origin: https://blog.csdn.net/qq_36401512/article/details/91450292