Expected object of scalar type Long but got scalar type Double for argument #2 'target'

1.pytorch报错:

loss_class = torch.nn.CrossEntropyLoss()
s_data, s_label = data_source[0].to(DEVICE), data_source[1].to(DEVICE)
class_output, domain_output = model(input_data=s_data.float(), alpha=alpha)
# 报错位置如下:
err_s_label = loss_class(class_output, s_label)

报错内容如下:

Expected object of scalar type Long but got scalar type Double for argument #2 'target'

表示第二个位置的参数要求是Long类型,然而传入的时候是Double类型,因此我们只需:

s_label.long()

即可。

2. 如果会继续出现报错:

RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

表示在计算loss过程中遇到了多输出预测值。或者标签的维度是不同的(我这边标签的shape是(128, 1)),我们只需要将标签squeeze就行,具体参考torch.squeeze()函数,个人变动方法:

err_s_label = loss_class(class_output, s_label.squeeze(1).long())

3.总结:torch中数据类型的变化:


 

torch数据类型转换
数据类型 数据长度 用法
int int32 torch.int()
long int64 torch.long()
float float32 torch.float()
double float64 torch.double()

至于需要哪种变化,各位看官,请便。。。

发布了103 篇原创文章 · 获赞 55 · 访问量 14万+

猜你喜欢

转载自blog.csdn.net/l8947943/article/details/103732275
今日推荐