reunir () en Pytorch

Introducción a la función

efecto

Se utiliza para recopilar los valores en el tensor de entrada de la dimensión especificada

parámetro

​torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
  • entrada  ( Tensor ) – tensor de entrada

  • dim  ( int ) – el eje utilizado para indexar valores

  • índice  ( LongTensor ) – el valor del índice

  • sparse_grad  ( bool,  opcional ): si es verdadero, el gradiente del tensor de entrada se convertirá en un tensor disperso

  • out  ( Tensor,  opcional ) – el tensor de salida

Precauciones

la entrada y el índice deben tener las mismas dimensiones. Si d != dim, también requiere que index.size(d) <= input.size(d) para todas las dimensiones. la salida tiene la misma forma que el índice

Ejemplo de tensor 2D

tenue=0

  • Primero cree un tensor de entrada con valores del 1 al 16 y remodele
import torch

x = torch.range(1,16).view(4,4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • Luego crea el índice

Como [[0, 1, 2, 3], [3, 2, 1, 0]]

Primero vea [0, 1, 2, 3], el valor dentro significa seleccionar entre las filas 0, 1, 2 y 3 respectivamente , y luego porque [ 0 , 1, 2, 3] están ubicados respectivamente en el 0, 2da y 3ra filas en las columnas index 1 , 2 , 3 , por lo que la salida después de la indexación es: input[ 0 ][ 0 ] , input[ 1 ][ 1 ] , input[ 2 ][ 2 ] , input[ 3 ] [ 3 ] , a saber [1. , 6., 11., 16.]

Luego vea [3, 2, 1, 0], el valor interior significa seleccionar de las filas 3, 2, 1, 0 respectivamente, y luego porque [3, 2, 1, 0] están ubicados respectivamente en el 0 , 0 y 0th filas en el índice 1 , 2 , 3 columnas, por lo que la salida después de la indexación es: input[ 3 ][ 0 ] , input[ 2 ][ 1 ] , input[ 1 ][ 2 ] , input[ 0 ][ 3 ] , a saber [13. , 10., 7., 4.]

index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • Los resultados de la impresión muestran lo mismo que se esperaba
y = torch.gather(x, dim=0, index=index)
"""
tensor([[ 1.,  6., 11., 16.],
        [13., 10.,  7.,  4.]])
"""

tenue=1

  • Crear tensor de entrada
import torch

x = torch.range(1,16).view(4,4)
"""
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
"""
  • crear índice

Como [[0, 1, 2, 3], [3, 2, 1, 0]]

Primero vea [0, 1, 2, 3], el valor dentro significa seleccionar de las columnas 0, 1, 2 y 3 respectivamente, y luego porque [0, 1, 2 , 3] se encuentra en la fila 0 del índice , Por lo tanto, la salida después de indexar es: input[ 0 ][ 0 ] , input[ 0 ][ 1 ] , input[ 0 ][ 2 ] , input[ 0 ][ 3 ] , a saber [1., 2., 3 ., 4.]

Luego vea [3, 2, 1, 0], el valor dentro significa seleccionar de las columnas 3rd , 2 , 1 , 0 respectivamente, y porque [3, 2, 1, 0] se encuentra en la primera fila del índice , Entonces, la salida después de la indexación es: input[ 1 ][ 3 ] , input[ 1 ][ 2 ] , input[ 1 ][ 1 ] , input[ 1 ][ 0 ] , a saber [8., 7., 6. , 5.]

index = torch.LongTensor([[0, 1, 2, 3], [3, 2, 1, 0]])
"""
tensor([[0, 1, 2, 3],
        [3, 2, 1, 0]])
"""
  • Imprima el tensor, mostrando como se esperaba
y = torch.gather(x, dim=1, index=index)
"""
tensor([[1., 2., 3., 4.],
        [8., 7., 6., 5.]])
"""

Resumir

Al recopilar el tensor 2D, si dim = 0 o 1 , el valor en el índice indica que primero debe seleccionar de una determinada fila o columna , y luego ubicar el valor de acuerdo con la columna o fila en el índice , y puede obtener el valor Recolectar requerido

Supongo que te gusta

Origin blog.csdn.net/qq_38964360/article/details/131550919
Recomendado
Clasificación