Focal Loss for classification: PyTorch implementation code (continued, part 3)

PS: It still cannot be implemented with the NLLLoss function, but at least I have finally realized my idea. Below are the final and the initial Focal Loss implementations, which I then test against each other:

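import torch
import torch.nn as nn


# loss.py: binary classification
class FocalLoss(nn.Module):

    def __init__(self, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        # input: size M*2 (raw logits for the two classes), where M is the batch size
        # target: size M, with values 0 or 1
        pt = torch.softmax(input, dim=1)
        p = pt[:, 1]  # predicted probability of the positive class (class 1)
        # positive samples are weighted by alpha, negative samples by 1 - alpha
        loss = -self.alpha * (1 - p) ** self.gamma * (target * torch.log(p)) - \
               (1 - self.alpha) * p ** self.gamma * ((1 - target) * torch.log(1 - p))
        return loss.mean()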
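import torch
import torch.nn as nn


# loss2.py
class FocalLoss2(nn.Module):

    def __init__(self, gamma=0, alpha=1):
        super(FocalLoss2, self).__init__()
        self.gamma = gamma
        # Note: the default reduction="mean" averages the cross-entropy over the
        # batch, so the focal factor below is applied to the batch-mean loss
        self.ce = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, input, target):
        logp = self.ce(input, target)    # cross-entropy -log(p_t), averaged over the batch
        p = torch.exp(-logp)             # recover p_t from the cross-entropy
        loss = (1 - p) ** self.gamma * logp
        loss = self.alpha * loss         # the same alpha is used for every sample
        return loss.mean()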
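As a sketch of what a per-sample version of FocalLoss2 might look like, the class below computes the cross-entropy with reduction="none", so the (1 - p_t)^gamma factor is applied to each sample before averaging, and it optionally accepts one alpha per class. PerSampleFocalLoss is a hypothetical name, not something from this post or from PyTorch; the only library call it relies on is torch.nn.functional.cross_entropy with its reduction argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerSampleFocalLoss(nn.Module):
    """Hypothetical per-sample variant of FocalLoss2 (illustrative name).

    alpha is either a float (one weight for every sample, as in FocalLoss2)
    or a list/tuple with one weight per class, e.g. [1 - 0.25, 0.25].
    """

    def __init__(self, gamma=2.0, alpha=1.0):
        super().__init__()
        self.gamma = gamma
        if isinstance(alpha, (list, tuple)):
            # per-class weights stored as a buffer so .to(device) moves them too
            self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32))
        else:
            self.alpha = float(alpha)

    def forward(self, input, target):
        # one cross-entropy value per sample: ce_i = -log(p_t,i)
        ce = F.cross_entropy(input, target, reduction="none")
        p_t = torch.exp(-ce)                  # probability of the true class
        loss = (1 - p_t) ** self.gamma * ce   # focal modulation, per sample
        if isinstance(self.alpha, torch.Tensor):
            loss = self.alpha[target] * loss  # class-dependent alpha_t
        else:
            loss = self.alpha * loss
        return loss.mean()
```

With alpha=[1 - 0.25, 0.25] (class 0 weighted 0.75, class 1 weighted 0.25) this should agree with the binary FocalLoss above up to floating-point error, since both compute -alpha_t (1 - p_t)^gamma log(p_t); with gamma=0 and alpha=1 it reduces to ordinary cross-entropy.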

Test both on a small example:

import torch
from loss import FocalLoss
from loss2 import FocalLoss2

# four samples, two classes (raw logits)
input = torch.Tensor([[0.0543, 0.5641], [1.2221, -0.5496], [-0.7951, -0.1546], [-0.4557, 1.4724]])
target = torch.Tensor([1, 0, 1, 1])

# input = torch.Tensor([[0.0543, 0.5641], [1.2221, -0.5496]])
# target = torch.Tensor([1, 0])

# input = torch.Tensor([[0.0543, 0.5641]])
# target = torch.Tensor([1])

print(torch.softmax(input, dim=1))

criterion = FocalLoss(gamma=2, alpha=0.25)
criterion1 = FocalLoss2(gamma=2, alpha=0.25)
criterion2 = torch.nn.CrossEntropyLoss()

res = criterion(input, target)            # FocalLoss works with float targets
print(res)
res1 = criterion1(input, target.long())   # CrossEntropyLoss-based losses need int64 class indices
print(res1)
res2 = criterion2(input, target.long())
print(res2)

Output:

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0049)
tensor(0.2966)
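The 0.0049 reported for FocalLoss2 can be reproduced from the cross-entropy alone, which makes its batch-level behaviour visible. Below is a minimal stdlib-only sketch; it assumes the true-class probabilities read off the softmax printout above, which are rounded to four decimals, so the results are approximate.

```python
import math

# Probability of the true class for target = [1, 0, 1, 1], read off the
# softmax printout (rounded to 4 decimals, so results are approximate)
p_true = [0.6248, 0.8547, 0.6549, 0.8730]
gamma, alpha = 2, 0.25

# nn.CrossEntropyLoss() with its default reduction="mean"
ce_mean = sum(-math.log(p) for p in p_true) / len(p_true)

# What FocalLoss2 computes: the focal factor applied to the batch-mean loss
batch_level = alpha * (1 - math.exp(-ce_mean)) ** gamma * ce_mean

# Per-sample focal loss with the same single alpha, for comparison
per_sample = sum(alpha * (1 - p) ** gamma * -math.log(p) for p in p_true) / len(p_true)

print(round(ce_mean, 4), round(batch_level, 4), round(per_sample, 4))
```

Here the per-sample value lands noticeably closer to FocalLoss's 0.0080 than the batch-level value does; the remaining gap comes from FocalLoss weighting the one negative sample by 1 - alpha instead of alpha.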

Changing the target to target = torch.Tensor([0,1,0,0]) (every label flipped) gives:

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.5403)
tensor(0.2289)
tensor(1.5092)

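Both losses follow the same trend (they rise sharply once the labels are flipped), and their values are of a broadly similar order of magnitude. They are not identical: FocalLoss weights positive samples by alpha and negative samples by 1 - alpha, whereas FocalLoss2 multiplies every sample by the same alpha; and because nn.CrossEntropyLoss() averages over the batch by default, FocalLoss2 applies the (1 - p)^gamma factor to the batch-mean loss rather than to each sample. Even so, Focal Loss can probably be expressed in the FocalLoss2 form, and it seems reasonable to write it that way and put it on GitHub. This is my understanding; I hope it helps with yours.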
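The reason FocalLoss2 can stand in for FocalLoss is that, for two classes, both reduce to the same per-sample expression. Writing p for the softmax probability of class 1 and y in {0, 1} for the target, the binary formula in FocalLoss can be rewritten in terms of the true-class probability p_t, which FocalLoss2 recovers from the cross-entropy:

```latex
p_t = \begin{cases} p & y = 1 \\ 1 - p & y = 0 \end{cases},
\qquad
\alpha_t = \begin{cases} \alpha & y = 1 \\ 1 - \alpha & y = 0 \end{cases}

\mathrm{FL} = -\alpha\,(1-p)^{\gamma}\, y \log p \;-\; (1-\alpha)\, p^{\gamma}\,(1-y)\log(1-p)
            = -\alpha_t\,(1-p_t)^{\gamma}\log p_t

\mathrm{CE} = -\log p_t \;\Longrightarrow\; p_t = e^{-\mathrm{CE}},
\qquad
\mathrm{FL} = \alpha_t\,\bigl(1 - e^{-\mathrm{CE}}\bigr)^{\gamma}\,\mathrm{CE}
```

FocalLoss2 implements the last line, with a single alpha in place of alpha_t.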

Origin blog.csdn.net/qq_36401512/article/details/91969862