pytorch中的gather函数的含义

对于pytorch doc中的gather函数的理解,在官方文档中给的描述如下:

torch.gather(input,dim,index,out=None):

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
If input is an n-dimensional tensor with size (x0,x1...,xi−1,xi,xi+1,...,xn−1) and dim =i, then index must be an n-dimensional tensor with size (x0,x1,...,xi−1,y,xi+1,...,xn−1) where y≥1 and out will have the same size as index.

Parameters:	
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
out (Tensor, optional) – the destination tensor

Example:

>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])

gather中的参数理解如下:

  • input:需要操作的tensor,也就是源数据
  • dim:选择的维度,dim=0代表纵向选择,为1代表横向选择
  • index:相应的元素索引,其规模也就是最后输出的规模,其size应和源tensor相同
  • out:输出到某个指定的tensor

example解释:

从上可见源tensor为t=torch.tensor([[1,2],[3,4]]),dim=1代表index为横向的序号,相应位置的元素应该是对应行或列的某个元素,而该索引也即是index对应位置数值指明的。

最后输出的tensor形式为2*2型,因此第一行第0个位置为源数据第一行的第0个也即为1,第一行第1个位置为源数据第一行的第0个位置也就是1,第二行的第0个位置也即是源数据第二行第1个位置为4,第二行的第1个位置也即为源数据第二行的第0个位置为3,因此最后的gather结果为[[1,1],[4,3]]

如果dim改为0,那上例的结果应该为[[1,2],[3,2]].

猜你喜欢

转载自blog.csdn.net/pro_misefetion/article/details/84325261