Introducción a la función
efecto
Se utiliza para recopilar los valores en el tensor de entrada de la dimensión especificada
parámetro
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
-
entrada ( Tensor ) – tensor de entrada
-
dim ( int ) – el eje utilizado para indexar valores
-
índice ( LongTensor ) – el valor del índice
-
sparse_grad ( bool, opcional ): si es verdadero, el gradiente del tensor de entrada se convertirá en un tensor disperso
-
out ( Tensor, opcional ) – el tensor de salida
Precauciones
la entrada y el índice deben tener las mismas dimensiones. Si d != dim, también requiere que index.size(d) <= input.size(d) para todas las dimensiones. la salida tiene la misma forma que el índice
Ejemplo de tensor 2D
tenue=0
- Primero cree un tensor de entrada con valores del 1 al 16 y remodele
import torch
x = torch.range(1,16).view(4,4)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
"""
- Luego crea el índice
Como [[0, 1, 2, 3], [3, 2, 1, 0]]
Primero vea [0, 1, 2, 3], el valor dentro significa seleccionar entre las filas 0, 1, 2 y 3 respectivamente , y luego porque [ 0 , 1, 2, 3] están ubicados respectivamente en el 0, 2da y 3ra filas en las columnas index 1 , 2 , 3 , por lo que la salida después de la indexación es: input[ 0 ][ 0 ] , input[ 1 ][ 1 ] , input[ 2 ][ 2 ] , input[ 3 ] [ 3 ] , a saber [1. , 6., 11., 16.]
Luego vea [3, 2, 1, 0], el valor interior significa seleccionar de las filas 3, 2, 1, 0 respectivamente, y luego porque [3, 2, 1, 0] están ubicados respectivamente en el 0 , 0 y 0th filas en el índice 1 , 2 , 3 columnas, por lo que la salida después de la indexación es: input[ 3 ][ 0 ] , input[ 2 ][ 1 ] , input[ 1 ][ 2 ] , input[ 0 ][ 3 ] , a saber [13. , 10., 7., 4.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
[3, 2, 1, 0]])
"""
- Los resultados de la impresión muestran lo mismo que se esperaba
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1., 6., 11., 16.],
[13., 10., 7., 4.]])
"""
tenue=1
- Crear tensor de entrada
import torch
x = torch.range(1,16).view(4,4)
"""
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
"""
- crear índice
Como [[0, 1, 2, 3], [3, 2, 1, 0]]
Primero vea [0, 1, 2, 3], el valor dentro significa seleccionar de las columnas 0, 1, 2 y 3 respectivamente, y luego porque [0, 1, 2 , 3] se encuentra en la fila 0 del índice , Por lo tanto, la salida después de indexar es: input[ 0 ][ 0 ] , input[ 0 ][ 1 ] , input[ 0 ][ 2 ] , input[ 0 ][ 3 ] , a saber [1., 2., 3 ., 4.]
Luego vea [3, 2, 1, 0], el valor dentro significa seleccionar de las columnas 3rd , 2 , 1 , 0 respectivamente, y porque [3, 2, 1, 0] se encuentra en la primera fila del índice , Entonces, la salida después de la indexación es: input[ 1 ][ 3 ] , input[ 1 ][ 2 ] , input[ 1 ][ 1 ] , input[ 1 ][ 0 ] , a saber [8., 7., 6. , 5.]
index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
[3, 2, 1, 0]])
"""
- Imprima el tensor, mostrando como se esperaba
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
[8., 7., 6., 5.]])
"""
Resumir
Al recopilar el tensor 2D, si dim = 0 o 1 , el valor en el índice indica que primero debe seleccionar de una determinada fila o columna , y luego ubicar el valor de acuerdo con la columna o fila en el índice , y puede obtener el valor Recolectar requerido