Tutorial de aprendizaje minimalista de PyTorch: el uso y la diferencia de la función de pérdida nn.CrossEntropyLoss () y nn.NLLLoss ()

Ejecute el código y lo sabrá de un vistazo.
Código de muestra:

import torch
import torch.nn as nn
input = torch.randn(3, 3)
print(input)
sm = nn.Softmax(dim=1)
print(sm(input))
test1 = torch.log(sm(input))
print(test1)
print(abs(test1[0][0] + test1[1][2] + test1[2][1]) / 3)

loss = nn.NLLLoss()
target = torch.tensor([0, 2, 1])
l1 = loss(test1, target)
print(l1)

loss_1 = nn.CrossEntropyLoss()
print(loss_1(input, target))

resultado:

tensor([[-0.1215, -0.8342,  0.4117],
        [-1.0425,  0.4401,  1.0196],
        [ 0.3323, -0.0629,  0.0515]])
tensor([[0.3130, 0.1535, 0.5335],
        [0.0754, 0.3320, 0.5926],
        [0.4117, 0.2773, 0.3109]])
tensor([[-1.1615, -1.8742, -0.6283],
        [-2.5852, -1.1027, -0.5232],
        [-0.8873, -1.2825, -1.1682]])
tensor(0.9891)
tensor(0.9891)
tensor(0.9891)

Supongo que te gusta

Origin blog.csdn.net/qq_28057379/article/details/106815452
Recomendado
Clasificación