torch.gather()函数解析

torch.gather()函数解析

在使用过程中,对于gather的具体使用方法不甚了解,经过一番测试之后,终于弄懂了其使用方法,这里进行整理。

首先看一下这个方法的参数:

torch.gather(inputs, dim, index, *, sparse_grad=False, out=None) → Tensor

关键参数为inputs, dim, index,我将会按照自己的理解在后面详细解释每个参数的含义,这里只需要有一个简单的印象即可。

官方解释

For a 3-D tensor the output is specified by:

out[i][j][k] = inputs[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = inputs[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = inputs[i][j][index[i][j][k]]  # if dim == 2

其实这里的式子以及解释的很清楚了,奈何本人愚钝,很久才理解…

这个式子具体的含义其实很简单,就是对于给定的inputs,将参数dim所指定维度的索引值用参数index中相应的元素代替表示(因此index的维度必须和inputs相等,因为需要索引到inputs中每个元素),其余位置的索引按照顺序即0,1,2…

这么说可能比较模糊,我这里给出一个具体的例子:

import torch

inputs = torch.rand(2,2,2)
inputs, inputs.shape
# (tensor([[[0.2465, 0.5183],
#          [0.7851, 0.9985]],
 
#         [[0.8217, 0.2534],
#          [0.0478, 0.4805]]]),
# torch.Size([2, 2, 2]))

index = torch.tensor([[[1, 0],[0, 0]]])
index, index.shape
# (tensor([[[1, 0],
#          [0, 0]]]),
# torch.Size([1, 2, 2]))

这里的inputs的索引,可以表示为inputs[i][j][k],例如inputs[0][1][1]=0.9985

则inputs中的每个元素的索引分别为:

inputs[0][0][0], inputs[0][0][1]
inputs[0][1][0], inputs[0][1][1]

inputs[1][0][0], inputs[1][0][1]
inputs[1][1][0], inputs[1][1][1]

当我们指定dim=0时,调用gather()方法:

inputs.gather(0, index)
# tensor([[[0.4327, 0.2616],
#         [0.7656, 0.1954]]])

也就是说,对于inputs的第0维,我们将其索引i用张量index中的每个元素代替

例如第一个结果0.04327,我们知道它在结果矩阵的第一个位置上,因此它的j=0,k=0,原本该位置的i=0,经过gather()方法之后,替换为i=index[0][0][0],而index[0][0][0]=1,故该位置的元素为源矩阵[1][0][0]位置的元素,inputs[1][0][0]=0.4327

同理可得到每个位置的索引分别为:

inputs[1][0][0], inputs[0][0][1]
inputs[0][1][0], inputs[0][1][1],

也就是说,对于inputsdim=0位置的索引i,我们将其原本的顺序替换为index参数中的值。
我们再试一下令dim=1

inputs.gather(1, index)
# tensor([[[0.7656, 0.2616],
#         [0.2183, 0.2616]]])

每个位置的索引即为:

inputs[0][1][0], inputs[0][0][1]
inputs[0][0][0], inputs[0][0][1],

总结

一句话概括gather()方法的作用就是,对于给定的inputs,将参数dim所指定维度的索引值用参数index中相应的元素代替表示。

猜你喜欢

转载自blog.csdn.net/qq_45802280/article/details/127888341