Esto es esencialmente lo mismo que nn.BCELoss(), pero se agrega una función logits (es decir, la función sigmoide) a BCELoss. El ejemplo es el siguiente:
import torch
import torch.nn as nn
label = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))
loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))
loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))
Los resultados de salida son:
tensor(0.4963)
tensor(0.4963)
tensor(0.5990)
Se puede ver que nn.BCEWithLogitsLoss() es equivalente a hacer un sigmoide basado en el resultado de predicción predicho en nn.BCELoss(), y luego continúa calculando la pérdida normalmente. Entonces, esto implica un error bastante extraño. Si la red misma ya usó sigmoid para procesar el resultado de salida y usó nn.BCEWithLogitsLoss() al calcular la pérdida... Entonces será equivalente a calcular el sigmoide dos veces para obtener el resultado previsto. Puede haber todo tipo de problemas extraños.
Por ejemplo, la red no puede converger (tear cat cat head.jpg)
Árbitro
[1] https://zhuanlan.zhihu.com/p/170558960