pytorch-gather

函数torch.gather(inputdimindexout=Nonesparse_grad=False) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.

  • input (Tensor) – 原张量
  • dim (int) – 索引的轴,二维中dim=0代表以每一列为独立个体,对列中元素进行索引排序,dim=1代表以每一行为独立个体,对行中元素进行索引排序。
  • index (LongTensor) – 索引
  • out (Tensor, optional) – 目标张量
  • sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.(没用过)
b = torch.Tensor([[1,2,3],[4,5,6]])
print(b)
index_1 = torch.LongTensor([[0,1,0,1],[2,0,2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0],[1,1,1]])
print (torch.gather(b, dim=1, index=index_1))#以每一行为独立个体,对行中元素进行索引排序,所以索引表index的行数需要等于原矩阵的行数,对每列中的每个元素进行编号。
print (torch.gather(b, dim=0, index=index_2))#以每一列为独立个体,对列中元素进行索引排序,所以索引表index的列数需要等于原矩阵的列数,对每行中的每个元素进行编号。

上述输出为:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2., 1., 2.],
        [6., 4., 6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 3.],
        [4., 5., 6.]])

官方文档,三维举例:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Example:

>>> t = torch.tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]])) tensor([[ 1, 1],  [ 4, 3]])

三维情况说明:

三维中: dim=0代表以每一个小channel为独立个体(一共有行x列个),对每个channel中的元素进行索引排序。 index可以有很多个channel但是行数、列数需要等于原矩阵的函数列数,否则,会超出索引范围。                              

               dim=1代表以每一列为独立个体,对列中元素进行索引排序,

               dim=2代表以每一行为独立个体,对行中元素进行索引排序。

a = torch.randint(0, 30, (2, 3, 5))
print(a)
index = torch.LongTensor([[[0,1,2,0,2],
                           [0,0,0,0,0],
                           [1,1,1,1,1]],
                          [[1,2,2,2,2],
                           [0,0,0,0,0],
                           [2,2,2,2,2]]])

index2 = torch.LongTensor([[[0,1,1,0,1],
                            [0,1,1,1,1],
                            [1,1,1,1,1]],
                           [[1,0,0,0,0],
                            [0,0,0,0,0],
                            [1,1,0,0,0]],
                           [[1,0,0,0,0],
                            [0,0,0,0,0],
                            [1,1,0,0,0]]])

# dim=0
b = torch.gather(a,0,index2)
print("dim=0:\n",b)

#dim=1
c = torch.gather(a,1,index)
print("dim=1:\n",c)

#dim=2
d = torch.gather(a,2,index)
print("dim=2:\n",d)

输出:

tensor([[[26,  5, 16,  8,  8],
         [22,  7,  9, 27, 12],
         [25, 10,  7,  6,  4]],

        [[ 4, 11,  2,  2,  2],
         [12,  0, 21, 13,  7],
         [ 2, 20, 13, 26,  2]]])
dim=0:
 tensor([[[26, 11,  2,  8,  2],
         [22,  0, 21, 13,  7],
         [ 2, 20, 13, 26,  2]],

        [[ 4,  5, 16,  8,  8],
         [22,  7,  9, 27, 12],
         [ 2, 20,  7,  6,  4]],

        [[ 4,  5, 16,  8,  8],
         [22,  7,  9, 27, 12],
         [ 2, 20,  7,  6,  4]]])
dim=1:
 tensor([[[26,  7,  7,  8,  4],
         [26,  5, 16,  8,  8],
         [22,  7,  9, 27, 12]],

        [[12, 20, 13, 26,  2],
         [ 4, 11,  2,  2,  2],
         [ 2, 20, 13, 26,  2]]])
dim=2:
 tensor([[[26,  5, 16, 26, 16],
         [22, 22, 22, 22, 22],
         [10, 10, 10, 10, 10]],

        [[11,  2,  2,  2,  2],
         [12, 12, 12, 12, 12],
         [13, 13, 13, 13, 13]]])
dim = 0的时候(三维情况下),举的例子只有2 channels。所以index在0,1两个之间选择。 输出的矩阵元素也是按照index的指定。分别在1st channel和2nd channel之间选。 index [0,1,1,0,1]的分别代表第一个元素在1st channel选,第二个元素在2nd channel选,第三个元素在2nd channel选,第四个元素在1st channel选,第五个元素在2nd channel选。
 

猜你喜欢

转载自www.cnblogs.com/oliyoung/p/pytorch-gather.html