函数介绍
作用
用于从指定维度gather输入张量中的数值
参数
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
-
input (Tensor) – 输入张量
-
dim (int) – 用于索引取值的轴
-
index (LongTensor) – 索引值
-
sparse_grad (bool, optional) – 若为True,输入张量的梯度会变成稀疏张量
-
out (Tensor, optional) – 输出张量
注意事项
input和index必须具有相同的维数。如果d != dim,还要求所有维度的index.size(d) <= input.size(d)。output与index的形状相同
2D-Tensor示例
dim=0
- 首先创建一个数值从1到16的输入张量并reshape
import torch
x = torch.range(1,16).view(4,4)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
"""
- 接着创建index
如[[0, 1, 2, 3], [3, 2, 1, 0]]
首先看到[0, 1, 2, 3],里面的数值表示分别从第0、1、2、3行进行选择,然后因为[0, 1, 2, 3]分别位于index中的第0、1、2、3列,因此索引后的输出为:input[0][0]、input[1][1]、input[2][2]、input[3][3],即[1., 6., 11., 16.]
接着看到[3, 2, 1, 0],里面的数值表示分别从第3、2、1、0行进行选择,然后因为[3, 2, 1, 0]分别位于index中的第0、1、2、3列,因此索引后的输出为:input[3][0]、input[2][1]、input[1][2]、input[0][3],即[13., 10., 7., 4.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
[3, 2, 1, 0]])
"""
- 打印输出结果,显示与预期的一致
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1., 6., 11., 16.],
[13., 10., 7., 4.]])
"""
dim=1
- 创建输入张量
import torch
x = torch.range(1,16).view(4,4)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
"""
- 创建index
如[[0, 1, 2, 3], [3, 2, 1, 0]]
首先看到[0, 1, 2, 3],里面的数值表示分别从第0、1、2、3列进行选择,然后因为[0, 1, 2, 3]位于index中的第0行,因此索引后的输出为:input[0][0]、input[0][1]、input[0][2]、input[0][3],即[1., 2., 3., 4.]
接着看到[3, 2, 1, 0],里面的数值表示分别从第3、2、1、0列进行选择,然后因为[3, 2, 1, 0]位于index中的第1行,因此索引后的输出为:input[1][3]、input[1][2]、input[1][1]、input[1][0],即[8., 7., 6., 5.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
[3, 2, 1, 0]])
"""
- 打印输出张量,显示与预期的一致
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
[8., 7., 6., 5.]])
"""
总结
对2D-tensor进行gather时,若dim=0 or 1,则index中的数值表示首先应从某行 or 列进行选择,再根据该数值在index中所处的列 or 行进行定位,便可得到需要gather的数值