下面的程序会报错IndexError: tensors used as indices must be long, byte or bool tensors
mask = torch.Tensor([True,True,False])
a = torch.Tensor([3,2,1])
a[mask]=0
print(a)
原因是索引要为long, byte 或 bool类型,因此需要将mask转换为bool类型,如下:
mask = torch.Tensor([True,True,False]).type(torch.bool)
a = torch.Tensor([3,2,1])
a[mask]=0
print(a)
tensor([0., 0., 1.])