Cómo usa Pytorch la pérdida focal

Focal lossEs una función de pérdida propuesta en el artículo Pérdida focal para la detección de objetos densos para muestras simples decay. Es Cross Entropy Lossuna mejora con respecto al estándar. FLPara muestras simples (p es relativamente grande), la respuesta es una pérdida pequeña. Como se muestra en la Figura 1 del artículo, cuando p = 0,6, el estándar es CEmayor loss, pero hay una respuesta de pérdida relativamente pequeña a FL. Este es un tipo de decadencia para muestras simples. Entre ellos, alfa está relacionado con la frecuencia de cada categoría en los datos de entrenamiento, pero la siguiente implementación se basa en alfa = 1 para los experimentos.

inserte la descripción de la imagen aquí

PyTorchPara usarlo en Focal Loss, puede seguir los pasos a continuación

método uno:

1. Cree el archivo FocalLoss.py y agregue el código.

inserte la descripción de la imagen aquí

Modificación de código:

  • classnumCámbialo al número de tu clasificación.
  • P = F.softmax(entradas) cambió a P = F.softmax(entradas, tenue=1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num=5, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

2. Añade módulos a tu función de formación

from FocalLoss import FocalLoss

loss = FocalLoss()

Método dos:

Primero, asegúrese de haber importado el módulo torchy torch.nn, que torch.nnproporciona varias funciones de pérdida comunes.

import torch
import torch.nn as nn

Luego, defina una clase de pérdida focal personalizada que herede de torch.nn.Module. En el constructor de la clase, puede especificar los parámetros requeridos por Focal Loss, como γ (factor de ajuste) y peso.

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(weight=self.weight)(inputs, targets)  # 使用交叉熵损失函数计算基础损失
        pt = torch.exp(-ce_loss)  # 计算预测的概率
        focal_loss = (1 - pt) ** self.gamma * ce_loss  # 根据Focal Loss公式计算Focal Loss
        return focal_loss

A continuación, durante el entrenamiento del modelo, utilice la pérdida focal personalizada en lugar de la función de pérdida de entropía cruzada.

# 定义模型
model = YourModel()

# 定义损失函数(使用自定义的Focal Loss)
criterion = FocalLoss(gamma=2, weight=None)

# 初始化优化器等

# 开始训练循环
for epoch in range(num_epochs):
    # 前向传播、计算损失
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播、更新模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 其他操作(如打印训练日志等)

Mediante los pasos anteriores, la función de pérdida se puede cambiar de la función de pérdida de entropía cruzada a Focal Loss en PyTorch. Tenga en cuenta que es posible que algunos detalles (como modelo, entrada, optimizador, etc.) en los ejemplos de código anteriores deban modificarse y complementarse de acuerdo con su situación real.

Supongo que te gusta

Origin blog.csdn.net/weixin_45277161/article/details/132626946
Recomendado
Clasificación