Pytorch中的torch.gather函数的用法

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
# index,index, index
# dim=1,行,水平
# dim=0,列,竖直
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)

结果:

 1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]

torch.gather(b, dim=1, index=index_1)
第一步:dim=1水平方向 1 2 3,4 5 6
第二步:            index, 0,1;2,0 
res:                     [1,2] [6, 4]

 1  2
 6  4
[torch.FloatTensor of size 2x2]


 1  5  6
 1  2  3
[torch.FloatTensor of size 2x3]



猜你喜欢

转载自blog.csdn.net/gz153016/article/details/108937761