torch.nn.BCEWithLogitsLoss vs. torch.nn.BCELoss

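torch.nn.BCEWithLogitsLoss is equivalent to a sigmoid followed by torch.nn.BCELoss. It takes raw logits and applies the sigmoid internally. According to the PyTorch documentation, combining the two steps in one class is also more numerically stable. A code example follows: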

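import torch
import torch.nn as nn

# BCEWithLogitsLoss takes raw logits; BCELoss expects probabilities in [0, 1]
bce_with_logits = nn.BCEWithLogitsLoss()
bce = nn.BCELoss()

x = torch.randn((1,))      # one random logit
y = torch.tensor([1.0])    # target label as a float tensor

loss_with_logits = bce_with_logits(x, y)   # sigmoid is applied internally
loss_bce = bce(torch.sigmoid(x), y)        # sigmoid is applied explicitly

print("BCEWithLogitsLoss:", loss_with_logits)
print("BCELoss:", loss_bce)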


# Example output (x is random, so the value differs between runs):
"""
BCEWithLogitsLoss: tensor(0.2138)
BCELoss: tensor(0.2138)
"""
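Why prefer BCEWithLogitsLoss? For a large logit, sigmoid(x) rounds to exactly 1.0, so log(1 - sigmoid(x)) becomes log(0). The PyTorch documentation states that BCELoss clamps its log outputs to be greater than or equal to -100, so the loss saturates at 100 instead of growing with x. The plain-Python sketch below illustrates this. It is not PyTorch's source code: `naive_bce` and `stable_bce` are hypothetical helpers, the -100 clamp imitates the documented BCELoss behavior, and the stable version uses the standard identity max(x, 0) - x*y + log(1 + exp(-|x|)), which never takes log(0).

```python
import math


def naive_bce(x: float, y: float) -> float:
    """Hypothetical helper: sigmoid followed by binary cross-entropy,
    with each log clamped at -100 to imitate nn.BCELoss."""
    p = 1.0 / (1.0 + math.exp(-x))  # sigmoid; rounds to exactly 1.0 for large x

    def clamped_log(v: float) -> float:
        # log(0) would be -inf; clamp it to -100 as BCELoss does
        return -100.0 if v <= 0.0 else max(math.log(v), -100.0)

    return -(y * clamped_log(p) + (1.0 - y) * clamped_log(1.0 - p))


def stable_bce(x: float, y: float) -> float:
    """Hypothetical helper: BCE computed directly from the logit x
    as max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


for x, y in [(0.5, 1.0), (200.0, 0.0)]:
    print(f"x={x}, y={y}: naive={naive_bce(x, y):.4f}, stable={stable_bce(x, y):.4f}")
```

For a moderate logit (x=0.5, y=1) both give about 0.4741. For x=200, y=0 the naive version returns 100 because of the clamp, while the logits-based version returns 200. The same saturation can be expected when nn.BCELoss is applied to torch.sigmoid of a similarly large float32 logit.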


Reposted from blog.csdn.net/qq_38964360/article/details/131696071