An Explanation of Several Loss Functions for Classification

1. The LabelSmoothingCrossEntropy Loss Function

A walkthrough of the LabelSmoothingCrossEntropy label-smoothing loss function.

import torch
import torch.nn as nn

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps                      # smoothing factor
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]               # number of classes
        log_pred = torch.log_softmax(output, dim=-1)
        # Smoothing term: negative log-probability summed over all classes
        if self.reduction == 'sum':
            loss = -log_pred.sum()
        else:
            loss = -log_pred.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        # eps / c weights the uniform term, (1 - eps) weights the standard NLL term
        return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
                                                                                   reduction=self.reduction,
                                                                                   ignore_index=self.ignore_index)
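
For mean reduction this is equivalent to the label smoothing built into PyTorch itself. A minimal sanity check, assuming PyTorch 1.10 or later (where nn.CrossEntropyLoss gained its label_smoothing argument):

import torch
import torch.nn as nn

logits = torch.randn(4, 3)
target = torch.randint(0, 3, (4,))

custom = LabelSmoothingCrossEntropy(eps=0.1)(logits, target)
builtin = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, target)
print(torch.allclose(custom, builtin))  # expected: True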

Let's first work through what the nll_loss here computes.

import torch
import torch.nn as nn

inputs = torch.tensor([[-0.1342, -2.5835, -0.9810],
                       [0.1867, -1.4513, -0.3225],
                       [0.6272, -0.1120, 0.3048]])
target = torch.tensor([0, 2, 1])
loss = nn.NLLLoss(reduction='mean', ignore_index=-100)
result = loss(inputs, target)
print(result)  # tensor(0.1896)

Here result = 0.189566667. The calculation works as follows: target 0 selects -0.1342 from the first row, target 2 selects -0.3225 from the second row, and target 1 selects -0.1120 from the third row; these entries are negated and averaged, giving the final result

(0.1342 + 0.3225 + 0.1120) / 3 = 0.189566667
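
The same number can be recovered by direct indexing; a quick sanity check using the tensors defined above:

# nll_loss simply picks inputs[i, target[i]] for each row, negates, and averages
manual = -inputs[torch.arange(len(target)), target].mean()
print(manual)  # tensor(0.1896)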
Now suppose the input to LabelSmoothingCrossEntropy is

output = torch.tensor([[0.2, 0.2, 0.6], [0.1, 0.1, 0.8]])
target = torch.tensor([1, 2])
loss = LabelSmoothingCrossEntropy()
loss(output, target)

After this input is passed in, c = 3 (the number of classes), and

log_pred = torch.log_softmax(output,dim=-1)

which gives

log_pred = tensor([[-1.2504, -1.2504, -0.8504],
                   [-1.3897, -1.3897, -0.6897]])
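
As a check, log_softmax is just each logit minus the log-sum-exp of its row; a minimal sketch using the output tensor above:

manual_log_pred = output - output.exp().sum(dim=-1, keepdim=True).log()
print(manual_log_pred)
# tensor([[-1.2504, -1.2504, -0.8504],
#         [-1.3897, -1.3897, -0.6897]])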

Next comes the branch that computes the averaged smoothing term:

if self.reduction == 'sum':
    loss = -log_pred.sum()
else:
    loss = -log_pred.sum(dim=-1)
    # loss = tensor([3.3513, 3.4692])
    if self.reduction == 'mean':
        loss = loss.mean()
        # mean loss = 3.4102
return self.eps * loss / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
                                                                            reduction=self.reduction,
                                                                            ignore_index=self.ignore_index)

In other words, a weight of eps = 0.1 goes to the smoothing term (the summed -log_softmax averaged over the c classes), and a weight of 1 - eps = 0.9 goes to the standard nll_loss term.
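
Finishing the worked example by hand: the nll_loss term selects -log_pred[0, 1] = 1.2504 and -log_pred[1, 2] = 0.6897, so the final loss is

0.1 * 3.4102 / 3 + 0.9 * (1.2504 + 0.6897) / 2 = 0.1137 + 0.8730 ≈ 0.9867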

2. The FocalLoss Loss Function (works quite well in practice)

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation"""
    def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma                  # focusing parameter
        self.weight = weight                # optional per-class weights
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        log_pt = torch.log_softmax(input, dim=1)
        pt = torch.exp(log_pt)              # predicted class probabilities
        # Down-weight well-classified examples by the modulating factor (1 - pt)^gamma
        log_pt = (1 - pt) ** self.gamma * log_pt
        loss = torch.nn.functional.nll_loss(log_pt, target, self.weight,
                                            reduction=self.reduction,
                                            ignore_index=self.ignore_index)
        return loss

The corresponding formula (γ = 2 by default; note that the modulating factor multiplies the log-probability, matching the code above):

$$\text{log\_pt} = \left(1 - e^{\text{log\_softmax}(\text{input},\, \text{dim}=1)}\right)^{\gamma} \cdot \text{log\_softmax}(\text{input},\, \text{dim}=1)$$
Then the nll_loss between this modulated log_pt and the target is computed as the final loss.
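
A minimal usage sketch, reusing the output and target tensors from the label smoothing example above; the (1 - pt)^γ factor down-weights well-classified examples, so the focal loss comes out well below plain cross-entropy:

output = torch.tensor([[0.2, 0.2, 0.6], [0.1, 0.1, 0.8]])
target = torch.tensor([1, 2])

focal = FocalLoss(gamma=2)(output, target)
ce = nn.CrossEntropyLoss()(output, target)
print(focal, ce)  # focal ≈ 0.4040 vs. cross-entropy ≈ 0.9701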


Reposted from blog.csdn.net/znevegiveup1/article/details/120323924