对于pytorch doc中的gather函数的理解,在官方文档中给的描述如下:
torch.gather(input,dim,index,out=None):
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
If input is an n-dimensional tensor with size (x0,x1...,xi−1,xi,xi+1,...,xn−1) and dim =i, then index must be an n-dimensional tensor with size (x0,x1,...,xi−1,y,xi+1,...,xn−1) where y≥1 and out will have the same size as index.
Parameters:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
out (Tensor, optional) – the destination tensor
Example:
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
gather中的参数理解如下:
- input:需要操作的tensor,也就是源数据
- dim:选择的维度,dim=0代表纵向选择,为1代表横向选择
- index:相应的元素索引,其规模也就是最后输出的规模,其size应和源tensor相同
- out:输出到某个指定的tensor
example解释:
从上可见源tensor为t=torch.tensor([[1,2],[3,4]]),dim=1代表index为横向的序号,相应位置的元素应该是对应行或列的某个元素,而该索引也即是index对应位置数值指明的。
最后输出的tensor形式为2*2型,因此第一行第0个位置为源数据第一行的第0个也即为1,第一行第1个位置为源数据第一行的第0个位置也就是1,第二行的第0个位置也即是源数据第二行第1个位置为4,第二行的第1个位置也即为源数据第二行的第0个位置为3,因此最后的gather结果为[[1,1],[4,3]]
如果dim改为0,那上例的结果应该为[[1,2],[3,2]].