版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jacke121/article/details/83042423
数据类型只支持long类型
import torch
import torch.nn as nn
loss = nn.CrossEntropyLoss()
# input, NxC=2x3
input = torch.randn(2, 3, requires_grad=True)
# target, N
target = torch.empty(2, dtype=torch.long).random_(3)
output = loss(input, target)
output.backward()
---------------------
作者:AIHGF
来源:CSDN
原文:https://blog.csdn.net/zziahgf/article/details/80196376?utm_source=copy
版权声明:本文为博主原创文章,转载请附上博文链接!