unable to get repr for class 'torch.tensor'

unable to get repr for class 'torch.tensor'

出错代码:

batch_conf.gather(1, conf_t.view(-1,1))

最近码代码使用pytorch遇到如题所示的问题,查遍Google百度,大多是说运算时维度不符,但是我找遍代码也没发现有这个错误。一段时间后才发现,网络参数保存的是torch.float32类型,而我输入的数据是torch.float64类型,将数据类型更改为torch.float32,问题解决。
我是因为是用别人的训练代码,没有改完,除了bug,导致最后输出的神经元个数(类别数)小于给的label-1(从0开始)的值。必须是神经元个数即类别数要完全等于maximum label value-1,比如分成10类,label最大只能是9,超过9的情况出现就会出现题目中的错误,然后pytorch还没有提示。。。
 

发布了2853 篇原创文章 · 获赞 1112 · 访问量 581万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/105449473