Focal loss
Es una función de pérdida propuesta en el artículo Pérdida focal para la detección de objetos densos para muestras simples decay
. Es Cross Entropy Loss
una mejora con respecto al estándar. FL
Para muestras simples (p es relativamente grande), la respuesta es una pérdida pequeña. Como se muestra en la Figura 1 del artículo, cuando p = 0,6, el estándar es CE
mayor loss
, pero hay una respuesta de pérdida relativamente pequeña a FL. Este es un tipo de decadencia para muestras simples. Entre ellos, alfa está relacionado con la frecuencia de cada categoría en los datos de entrenamiento, pero la siguiente implementación se basa en alfa = 1 para los experimentos.
PyTorch
Para usarlo en Focal Loss
, puede seguir los pasos a continuación
método uno:
1. Cree el archivo FocalLoss.py y agregue el código.
Modificación de código:
classnum
Cámbialo al número de tu clasificación.- P = F.softmax(entradas) cambió a P = F.softmax(entradas, tenue=1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
r"""
This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num=5, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(1)
P = F.softmax(inputs)
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.)
#print(class_mask)
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)]
probs = (P*class_mask).sum(1).view(-1,1)
log_p = probs.log()
#print('probs size= {}'.format(probs.size()))
#print(probs)
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
#print('-----bacth_loss------')
#print(batch_loss)
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
2. Añade módulos a tu función de formación
from FocalLoss import FocalLoss
loss = FocalLoss()
Método dos:
Primero, asegúrese de haber importado el módulo torch
y torch.nn
, que torch.nn
proporciona varias funciones de pérdida comunes.
import torch
import torch.nn as nn
Luego, defina una clase de pérdida focal personalizada que herede de torch.nn.Module
. En el constructor de la clase, puede especificar los parámetros requeridos por Focal Loss, como γ (factor de ajuste) y peso.
class FocalLoss(nn.Module):
def __init__(self, gamma=2, weight=None):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
def forward(self, inputs, targets):
ce_loss = nn.CrossEntropyLoss(weight=self.weight)(inputs, targets) # 使用交叉熵损失函数计算基础损失
pt = torch.exp(-ce_loss) # 计算预测的概率
focal_loss = (1 - pt) ** self.gamma * ce_loss # 根据Focal Loss公式计算Focal Loss
return focal_loss
A continuación, durante el entrenamiento del modelo, utilice la pérdida focal personalizada en lugar de la función de pérdida de entropía cruzada.
# 定义模型
model = YourModel()
# 定义损失函数(使用自定义的Focal Loss)
criterion = FocalLoss(gamma=2, weight=None)
# 初始化优化器等
# 开始训练循环
for epoch in range(num_epochs):
# 前向传播、计算损失
outputs = model(inputs)
loss = criterion(outputs, targets)
# 反向传播、更新模型参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 其他操作(如打印训练日志等)
Mediante los pasos anteriores, la función de pérdida se puede cambiar de la función de pérdida de entropía cruzada a Focal Loss en PyTorch. Tenga en cuenta que es posible que algunos detalles (como modelo, entrada, optimizador, etc.) en los ejemplos de código anteriores deban modificarse y complementarse de acuerdo con su situación real.