P.S. Although this cannot be implemented with the NLLLoss function, I did at least finally realize my idea. Now let's test the final and the initial Focal Loss implementations as follows:
import torch
import torch.nn as nn
# Binary classification
class FocalLoss(nn.Module):
    """Binary focal loss computed from two-class logits.

    For each sample, with ``p`` the softmax probability of class 1, the loss is

        -alpha * (1 - p)**gamma * target * log(p)
        - (1 - alpha) * p**gamma * (1 - target) * log(1 - p)

    and the per-sample values are averaged over the batch.

    Args:
        gamma: Focusing exponent applied to the probability of the wrong class.
        alpha: Weight of the positive class (index 1); the negative class
            (index 0) is weighted by ``1 - alpha``.
    """

    def __init__(self, gamma=2, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        """Return the mean focal loss of a batch.

        Args:
            input: Logits of size M x 2, where M is the batch size.
            target: Labels of size M; the formula treats them as 0/1 weights.

        Returns:
            A scalar tensor holding the batch-averaged loss.
        """
        # Take both log-probabilities from log_softmax instead of calling
        # torch.log on softmax output: once the softmax saturates to exactly
        # 0 or 1, log() returns -inf and the masked term 0 * -inf becomes NaN.
        log_probs = torch.log_softmax(input, dim=1)
        log_neg = log_probs[:, 0]  # log(1 - p)
        log_pos = log_probs[:, 1]  # log(p)
        p = log_pos.exp()
        # 1 - p, obtained directly from the class-0 probability.
        q = log_neg.exp()

        pos_term = self.alpha * q ** self.gamma * (target * log_pos)
        neg_term = (1 - self.alpha) * p ** self.gamma * ((1 - target) * log_neg)
        loss = -pos_term - neg_term
        return loss.mean()
import torch
import torch.nn as nn
class FocalLoss2(nn.Module):
    """Focal loss built on top of cross-entropy.

    For each sample, with ``ce`` its cross-entropy and ``p_t = exp(-ce)`` the
    probability assigned to the true class, the loss is

        alpha * (1 - p_t)**gamma * ce

    and the per-sample values are averaged over the batch. With the defaults
    (``gamma=0``, ``alpha=1``) this is the plain mean cross-entropy.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Constant factor multiplied into every sample's loss.
    """

    def __init__(self, gamma=0, alpha=1):
        super().__init__()
        self.gamma = gamma
        # reduction="none" keeps one loss value per sample so the modulating
        # factor (1 - p_t)**gamma is applied per sample. With the default
        # "mean" reduction the factor would be computed once from the
        # batch-averaged loss and scale the whole batch uniformly.
        self.ce = nn.CrossEntropyLoss(reduction="none")
        self.alpha = alpha

    def forward(self, input, target):
        """Return the mean focal loss of a batch.

        Args:
            input: Logits, passed unchanged to ``nn.CrossEntropyLoss``.
            target: Targets, passed unchanged to ``nn.CrossEntropyLoss``.

        Returns:
            A scalar tensor holding the batch-averaged loss.
        """
        ce_per_sample = self.ce(input, target)
        # Cross-entropy is -log(p_t), so exponentiating its negation gives p_t.
        p_t = torch.exp(-ce_per_sample)
        focal = (1 - p_t) ** self.gamma * ce_per_sample
        return (self.alpha * focal).mean()
Test both implementations with the following code:
import torch
from loss import FocalLoss
from loss2 import FocalLoss2
# Logits for a batch of four samples (one row per sample, one column per
# class) and their labels, stored as floats.
logits = torch.tensor([
    [0.0543, 0.5641],
    [1.2221, -0.5496],
    [-0.7951, -0.1546],
    [-0.4557, 1.4724],
])
target = torch.tensor([1.0, 0.0, 1.0, 1.0])

# Smaller batches to try instead:
# logits = torch.tensor([[0.0543, 0.5641], [1.2221, -0.5496]])
# target = torch.tensor([1.0, 0.0])
# logits = torch.tensor([[0.0543, 0.5641]])
# target = torch.tensor([1.0])

# Show the class probabilities the losses are computed from.
print(torch.softmax(logits, dim=1))

# FocalLoss receives the float labels as they are; FocalLoss2 and the plain
# cross-entropy receive them converted to integer class indices.
class_ids = target.long()
runs = (
    (FocalLoss(gamma=2, alpha=0.25), target),
    (FocalLoss2(gamma=2, alpha=0.25), class_ids),
    (torch.nn.CrossEntropyLoss(), class_ids),
)
for criterion, labels in runs:
    print(criterion(logits, labels))
tensor([[0.3752, 0.6248],
[0.8547, 0.1453],
[0.3451, 0.6549],
[0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0049)
tensor(0.2966)
Changing the target to `target = torch.Tensor([0,1,0,0])` gives:
tensor([[0.3752, 0.6248],
[0.8547, 0.1453],
[0.3451, 0.6549],
[0.1270, 0.8730]])
tensor(0.5403)
tensor(0.2289)
tensor(1.5092)
The results show that Focal Loss and Focal Loss 2 follow similar trends, and their values are of roughly the same order of magnitude. In fact, Focal Loss can probably be represented by Focal Loss 2. Presumably it is fine, given that it was written this way and published on GitHub. This is my understanding, and I hope it helps yours.