Comprensión y uso sencillos de Pytorch nn.BCEWithLogitsLoss()

Esto es esencialmente lo mismo que nn.BCELoss(), pero se agrega una función logits (es decir, la función sigmoide) a BCELoss. El ejemplo es el siguiente:

import torch
import torch.nn as nn

label = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))

loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))

Los resultados de salida son:

tensor(0.4963)
tensor(0.4963)
tensor(0.5990)

Se puede ver que nn.BCEWithLogitsLoss() es equivalente a hacer un sigmoide basado en el resultado de predicción predicho en nn.BCELoss(), y luego continúa calculando la pérdida normalmente. Entonces, esto implica un error bastante extraño. Si la red misma ya usó sigmoid para procesar el resultado de salida y usó nn.BCEWithLogitsLoss() al calcular la pérdida... Entonces será equivalente a calcular el sigmoide dos veces para obtener el resultado previsto. Puede haber todo tipo de problemas extraños.

Por ejemplo, la red no puede converger (tear cat cat head.jpg)

Árbitro

[1] https://zhuanlan.zhihu.com/p/170558960

Supongo que te gusta

Origin blog.csdn.net/qq_40714949/article/details/120295651
Recomendado
Clasificación