b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
# index,index, index
# dim=1,行,水平
# dim=0,列,竖直
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
结果:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
torch.gather(b, dim=1, index=index_1)
第一步:dim=1水平方向 1 2 3,4 5 6
第二步: index, 0,1;2,0
res: [1,2] [6, 4]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
Pytorch中的torch.gather函数的用法
猜你喜欢
转载自blog.csdn.net/gz153016/article/details/108937761
今日推荐
周排行