pytorch: “multi-target not supported” error message

BUG

使用Cross_entropy损失函数时出现 RuntimeError: multi-target not supported at …

可能存在的问题

1)其标签必须为0~n-1,而且必须为1维的,如果设置标签为[nx1]的,则也会出现以上错误。
2)标签y打印:
tensor([[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0]], device='cuda:0')
对于n.CrossEntroyLoss,目标必须是间隔[0,#class]的单个数字,而不是一个热编码的目标向量。您的目标是[1,0],因此PyTorch认为您希望每个输入都有多个不受支持的标签。
替换您的one-hot编码标签:

[1, 0] --> 0

[0, 1] --> 1

输入以下代码可解决上述问题。

y = torch.argmax(y, dim=1)
发布了33 篇原创文章 · 获赞 3 · 访问量 5542

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/104523499