Focal Loss 分类问题 pytorch实现代码(续3)

ps:虽然无法用NLLLoss函数来实现.但好歹最后实现了自己的想法.现在再来测试下最后和最开始的Focal Loss如下:

import torch
import torch.nn as nn
 
#二分类
class FocalLoss(nn.Module):
 
    def __init__(self, gamma=2,alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha=alpha
    def forward(self, input, target):
        # input:size is M*2. M is the batch number
        # target:size is M.
        pt=torch.softmax(input,dim=1)
        p=pt[:,1]
        loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\
               (1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))
        return loss.mean()
import torch
import torch.nn as nn


class FocalLoss2(nn.Module):

    def __init__(self, gamma=0, alpha=1):
        super(FocalLoss2, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        loss = self.alpha * loss
        return loss.mean()

用代码去测试实例:

import torch
from loss import FocalLoss
from loss2 import FocalLoss2


input=torch.Tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557,  1.4724]])
target= torch.Tensor([1,0,1,1])

# input=torch.Tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496]])
# target= torch.Tensor([1,0])

# input=torch.Tensor([[ 0.0543,  0.5641]])
# target= torch.Tensor([1])



print(torch.softmax(input,dim=1))

criterion = FocalLoss(gamma=2,alpha=0.25)
criterion1 = FocalLoss2(gamma=2,alpha=0.25)
criterion2 = torch.nn.CrossEntropyLoss()


res = criterion(input, target)
print(res)
res1 = criterion1(input, target.long())
print(res1)
res2 = criterion2(input, target.long())
print(res2)



tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.0080)
tensor(0.0049)
tensor(0.2966)

改变target为:target= torch.Tensor([0,1,0,0])

tensor([[0.3752, 0.6248],
        [0.8547, 0.1453],
        [0.3451, 0.6549],
        [0.1270, 0.8730]])
tensor(0.5403)
tensor(0.2289)
tensor(1.5092)

发现Focal Loss和Focal Loss2趋势相类似,且数量级总体相差不大.其实Focal Loss大概可以用Focal Loss2表示,敢写出来放到github上应该没什么问题.这是我的理解,希望对你的理解有帮助.

猜你喜欢

转载自blog.csdn.net/qq_36401512/article/details/91969862
今日推荐