IndexError: tensors used as indices must be long, byte or bool tensors

下面的程序会报错IndexError: tensors used as indices must be long, byte or bool tensors

mask = torch.Tensor([True,True,False])
a = torch.Tensor([3,2,1])
a[mask]=0
print(a)

原因是索引要为long, byte 或 bool类型,因此需要将mask转换为bool类型,如下:

mask = torch.Tensor([True,True,False]).type(torch.bool)
a = torch.Tensor([3,2,1])
a[mask]=0
print(a)
tensor([0., 0., 1.])

猜你喜欢

转载自blog.csdn.net/weixin_38314865/article/details/105949825