Pytorch中的gather()

函数介绍

作用

用于从指定维度gather输入张量中的数值

参数

​torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • input (Tensor) – 输入张量

  • dim (int) – 用于索引取值的轴

  • index (LongTensor) – 索引值

  • sparse_grad (bool, optional) – 若为True,输入张量的梯度会变成稀疏张量

  • out (Tensor, optional) – 输出张量

注意事项

input和index必须具有相同的维数。如果d != dim,还要求所有维度的index.size(d) <= input.size(d)。output与index的形状相同

2D-Tensor示例

dim=0

  • 首先创建一个数值从1到16的输入张量并reshape
import torch

x = torch.range(1,16).view(4,4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • 接着创建index

如[[0, 1, 2, 3], [3, 2, 1, 0]]

首先看到[0, 1, 2, 3],里面的数值表示分别从第0123行进行选择,然后因为[0, 1, 2, 3]分别位于index中的第0123列,因此索引后的输出为:input[0][0]input[1][1]input[2][2]input[3][3],即[1.,  6.,  11.,  16.]

接着看到[3, 2, 1, 0],里面的数值表示分别从第3210行进行选择,然后因为[3, 2, 1, 0]分别位于index中的第0123列,因此索引后的输出为:input[3][0]input[2][1]input[1][2]input[0][3],即[13.,  10.,  7.,  4.]

index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • 打印输出结果,显示与预期的一致
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1.,  6., 11., 16.],
        [13., 10.,  7.,  4.]])
"""

dim=1

  • 创建输入张量
import torch

x = torch.range(1,16).view(4,4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • 创建index

如[[0, 1, 2, 3], [3, 2, 1, 0]]

首先看到[0, 1, 2, 3],里面的数值表示分别从第0123列进行选择,然后因为[0, 1, 2, 3]位于index中的第0行,因此索引后的输出为:input[0][0]input[0][1]input[0][2]input[0][3],即[1.,  2.,  3.,  4.]

接着看到[3, 2, 1, 0],里面的数值表示分别从第3210列进行选择,然后因为[3, 2, 1, 0]位于index中的第1行,因此索引后的输出为:input[1][3]input[1][2]input[1][1]input[1][0],即[8.,  7.,  6.,  5.]

index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • 打印输出张量,显示与预期的一致
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
        [8., 7., 6., 5.]])
"""

总结

对2D-tensor进行gather时,若dim=0 or 1,则index中的数值表示首先应从某行 or 列进行选择,再根据该数值在index中所处的列 or 行进行定位,便可得到需要gather的数值

猜你喜欢

转载自blog.csdn.net/qq_38964360/article/details/131550919