Pytorch的gather函数

torch.gather(input, dim, index, out=None) → Tensor 官方给的格式

import torch

a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,1]])

print(a.gather(1,index_1))
print(a.gather(0,index_2))
####输出结果
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 6.]])

简单来说,就是按照gather里面的索引取出目标对应的值

  • 首先,输出矩阵的维度和这个index的维度是一致的。
  • 从代码可以看出,gather取值主要看索引前面的dim取值。dim =1,表明索引列号,也就是横向取值,index_1的第一行[0,1],0列指的是a中的1,1列指的是a中的2,第二行[2,0],2列指的是a中对应的6,0列只的是对应的4。(注意这里输出矩阵元素的行列位置对应index元素行列的位置,这样就不会弄混了)
  • dim =0 时,表明索引行号,即纵向取值,index-2的第一行[0,1,1],第0行指的是a中的1,第1行指的是a中的5 ,下一个第1行指的是a中的6。index-2的第二行[0,0,1],第0行指的是a中的1,下一个第0行指的是对应的2,第1行指的是对应的6。
  • 注意,这个index是一个longtensor的类型,只为tensor的话会报错。

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/112301688