Función de recopilación de Pytorch

torch.gather (input, dim, index, out = None) → Formato oficial del tensor

import torch

a = torch.Tensor([[1,2,3],[4,5,6]])
print(a)
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,1]])

print(a.gather(1,index_1))
print(a.gather(0,index_2))
####输出结果
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [6., 4.]])
tensor([[1., 5., 6.],
        [1., 2., 6.]])

En pocas palabras, es extraer el valor correspondiente al objetivo de acuerdo con el índice en reunir

  • En primer lugar, la dimensión de la matriz de salida es coherente con la dimensión de este índice.
  • Como se puede ver en el código, el valor de recopilación depende principalmente del valor de atenuación delante del índice. dim = 1, que indica el número de columna de índice, que es el valor horizontal. La primera fila de index_1 es [0,1], la columna 0 se refiere a 1, la columna 1 se refiere a 2 en a, y la segunda fila [2,0 ], La columna 2 se refiere al 6 correspondiente en a, y la columna 0 solo corresponde a 4. (Tenga en cuenta que la posición de la fila y la columna del elemento de la matriz de salida corresponde a la posición del elemento de índice, por lo que no se confundirá)
  • Cuando dim = 0, indica el número de la fila del índice, es decir, el valor vertical. La primera fila de index-2 [0,1,1], la fila 0 se refiere al 1 en a, y la 1ª fila se refiere a el a 5, la siguiente primera línea se refiere al 6 en a. La segunda fila del índice 2 es [0,0,1], la fila 0 se refiere al 1 en a, la siguiente fila 0 se refiere al 2 correspondiente y la fila 1 se refiere al 6 correspondiente.
  • Tenga en cuenta que este índice es de tipo longtensor y se informará un error si es solo para el tensor.

Supongo que te gusta

Origin blog.csdn.net/weixin_42990464/article/details/112301688
Recomendado
Clasificación