Focal Loss 分类问题 pytorch实现代码(续1)

ps:感谢Code_Mart的解答,肯定了思路,不过他也不确定是否可以在pytorch中那么写.事情这样模棱两可让我很烦躁决定深究一下.看到博客https://blog.csdn.net/qq_22210253/article/details/85229988对CrossEntropyLoss的实测决定二分类的上再实测一下理解.

在图片二分类时,输入m张图片,输出一个m*2的Tensor(跟我的模型输出一样)。实际输入3张图片,分二类,最后的输出是一个3*2的Tensor,举例如下:

>>> input=torch.randn(3,2)
>>> input
tensor([[ 0.0082,  1.2996],
        [ 0.1396,  0.4143],
        [-1.6190,  1.1246]])

假设第一列是neg类,第二列是pos类.

然后对每一行使用Softmax,这样可以得到每张图片的概率分布.

>>> sm=torch.nn.Softmax(dim=1)
>>> sm(input)
tensor([[0.2156, 0.7844],
        [0.4318, 0.5682],
        [0.0604, 0.9396]])

然后对Softmax的结果取自然对数:

>>> torch.log(sm(input))
tensor([[-1.5343, -0.2429],
        [-0.8399, -0.5652],
        [-2.8060, -0.0623]])

Softmax后的数值都在0~1之间,所以ln之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。

假设我们现在Target是[1,0,1](第一张图片是pos,第二张是neg,第三张是pos)。第一行取第1个元素,第二行取第0个,第三行取第1个,去掉负号,结果是:[0.2429,0.8399,0.0623]。再求个均值,结果是:
 

>>> (0.2429+0.8399+0.0623)/3
0.3817

先用NLLLoss函数实验一下:

>>> loss=torch.nn.NLLLoss()
>>> target = torch.tensor([1,0,1])
>>> loss(torch.log(sm(input)),target)
tensor(0.3817)

再用CrossEntropyLoss实验一下

>>> celoss=torch.nn.CrossEntropyLoss()
>>> celoss(input,target)
tensor(0.3817)

果然如此,真的感谢堆排序宝宝作者给我这么直观的感觉.

现在,再次回过头看看昨天的博客里的代码好像有点问题了.先贴下昨天的

import torch
import torch.nn as nn
 
 
class FocalLoss(nn.Module):
 
    def __init__(self, gamma=2,alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss()
        self.alpha=alpha
    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = self.alpha*(1 - p) ** self.gamma * logp * target + \
               (1-self.alpha)*(p) ** self.gamma * logp * (1-target)
        return loss.mean()

其中p = torch.exp(-logp)是不是有点问题呢.继续上面的例子:

>>> logp=celoss(input,target)
>>> logp
tensor(0.3817)
>>> torch.exp(-logp)
tensor(0.6827)

感觉根本不能对应回去啊,我理解不了.我就换一下极端一点的数字如下测试

>>> input2=torch.tensor([[-1.8,1.8],[1.3,-1.2],[-1.6,1.5]])
>>> sm(input2)
tensor([[0.0266, 0.9734],
        [0.9241, 0.0759],
        [0.0431, 0.9569]])
>>> logp=celoss(input2,target)
>>> torch.exp(-logp)
tensor(0.9513)

这样一看,虽然反不回去,但是总体还是能体现概率.0.9513接近(0.9734,0.9241,0.9569)三者均值附近,而0.6827也在(0.7844,0.4318,0.9396)三者均值附近.好像是预测越准确越接近均值.但一个值替代不了矩阵中6个值.

猜你喜欢

转载自blog.csdn.net/qq_36401512/article/details/91491205
今日推荐