PyTorch's gather function

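Official signature: torch.gather(input, dim, index, out=None) → Tensor (recent releases document it as torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor)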
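As a quick check of the signature, the snippet below calls gather both as a function and as a tensor method, and also passes out= to write into a preallocated tensor. The tensors are re-created here so the snippet runs on its own; the names out_fn, out_method and buf are just for illustration.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index = torch.tensor([[0, 1], [2, 0]])  # integer data -> int64 (LongTensor) by default

# Functional form and method form compute the same thing.
out_fn = torch.gather(a, 1, index)
out_method = a.gather(1, index)

# out= writes the result into an existing tensor with index's shape and a's dtype.
buf = torch.empty(index.shape, dtype=a.dtype)
torch.gather(a, 1, index, out=buf)

print(out_fn)
print(torch.equal(out_fn, out_method), torch.equal(out_fn, buf))
```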

import torch

a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,1]])

print(a.gather(1,index_1))
print(a.gather(0,index_2))
#### Output
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 6.]])

Simply put, gather reads values out of the input tensor at the positions listed in index, moving along dimension dim.
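To make "reads values at the positions listed in index" concrete, here is a minimal loop-based sketch for the 2-D case only. The helper gather_2d is hypothetical (not a PyTorch API); it just spells out the rule out[i][j] = a[index[i][j]][j] for dim = 0 and out[i][j] = a[i][index[i][j]] for dim = 1, and checks itself against torch.gather on the a, index_1 and index_2 from the example above.

```python
import torch

def gather_2d(inp, dim, index):
    """Hypothetical helper: loop-based equivalent of torch.gather for 2-D tensors."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    rows, cols = index.shape
    out = torch.empty(rows, cols, dtype=inp.dtype)  # output has index's shape
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            if dim == 0:
                out[i, j] = inp[k, j]  # index chooses the row, column j is kept
            else:
                out[i, j] = inp[i, k]  # index chooses the column, row i is kept
    return out

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index_1 = torch.tensor([[0, 1], [2, 0]])
index_2 = torch.tensor([[0, 1, 1], [0, 0, 1]])

print(gather_2d(a, 1, index_1))
print(gather_2d(a, 0, index_2))
# The loop version agrees with the built-in on both examples.
print(torch.equal(gather_2d(a, 1, index_1), a.gather(1, index_1)))
print(torch.equal(gather_2d(a, 0, index_2), a.gather(0, index_2)))
```

The loops are only for understanding; in real code torch.gather does the same work in one vectorized call.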

  • First, the output always has the same shape as index (index_1 is 2×2, so a.gather(1, index_1) is 2×2).
  • As the code shows, what gather returns depends on dim. dim = 1 means each value in index is a column number, so values are taken horizontally within each row: out[i][j] = a[i][index[i][j]]. The first row of index_1 is [0, 1]: in row 0 of a, column 0 holds 1 and column 1 holds 2. The second row of index_1 is [2, 0]: in row 1 of a, column 2 holds 6 and column 0 holds 4. (Each output element sits at the same row and column as the index element that selected it, which keeps the two from getting confused.)
  • dim = 0 means each value in index is a row number, so values are taken vertically within each column: out[i][j] = a[index[i][j]][j]. The first row of index_2 is [0, 1, 1]: row 0 of column 0 gives 1, row 1 of column 1 gives 5, and row 1 of column 2 gives 6. The second row of index_2 is [0, 0, 1]: row 0 of column 0 gives 1, row 0 of column 1 gives 2, and row 1 of column 2 gives 6.
  • Note that index must be a LongTensor (dtype int64); passing an ordinary float Tensor raises an error.
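The dtype requirement is easy to see directly. This sketch builds the same index as index_1 but with float values, shows that gather rejects it (the exception type is assumed to be RuntimeError, which is what current releases raise for this check), and then converts it with .long() so the call succeeds.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
float_index = torch.tensor([[0., 1.], [2., 0.]])  # float32, not int64

try:
    a.gather(1, float_index)
    raised = False
except RuntimeError as err:  # assumed exception type for the dtype check
    raised = True
    print("error:", err)

# Converting the index to int64 (a LongTensor) makes the call valid.
fixed = a.gather(1, float_index.long())
print(fixed)
```

Integer literals passed to torch.tensor already produce int64, so in practice the error usually appears when an index is computed with float arithmetic.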


Source: blog.csdn.net/weixin_42990464/article/details/112301688