La différence entre torch.gather et tf.gather

Dans Tensorflow et PyTorch,

gather La fonction est utilisée pour agréger les éléments du tenseur d'entrée en fonction de l'index spécifié et renvoyer un nouveau tenseur.

Plus précisément, la fonction et l'utilisation de cette fonction sont les suivantes :

Fonctions dans TensorFlow tf.gather() :

tf.gather(params, indices, axis=None, batch_dims=0, name=None)

Parmi eux, params représente le tenseur d'entrée ; indices représente l'index qui doit être agrégé, qui peut être une liste ou un tenseur constant ; axis est un entier spécifiant la dimension d'agrégation, s'il n'est pas transmis, la valeur par défaut est 0 ; est un batch_dims entier spécifiant le nombre de les dimensions du lot, généralement utilisées lors de l'agrégation de plusieurs échantillons ; name sont un nom d'opération facultatif.

Par exemple, pour les tenseurs  params = [[1, 2], [3, 4], [5, 6], [7, 8]]qui nécessitent  [0, 2, 3] une agrégation sur la première dimension par index, vous pouvez utiliser le code suivant :

import tensorflow as tf
params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = [0, 2, 3]
output = tf.gather(params, indices, axis=0)
print(output)
# Output: [[1 2]
#          [5 6]
#          [7 8]]

Fonctions dans PyTorch  torch.gather() :

torch.gather(input, dim, index, out=None)

Parmi eux, input représente le tenseur d'entrée ; dim représente la dimension qui doit être agrégée ; index représente l'index qui doit être agrégé, qui peut être une liste ou un tenseur constant ; out et est un tenseur de sortie facultatif.

Par exemple, pour les tenseurs  input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])qui nécessitent  torch.tensor([[0], [2], [3]]) une agrégation sur la première dimension par index, vous pouvez utiliser le code suivant :

import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([[0], [2], [3]])
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1],
#                  [5],
#                  [7]])

——————————————————————————————

différent: 

L'index de tf.gather est destiné à  [0, 2, 3]时,可以l'agrégation dans la première dimension, mais

如果torch.gather的索引也设置为[0, 2, 3],则会报错:

RuntimeError: Index tensor must have the same number of dimensions as input tensor

而若torch.gather的索引设置为[[0], [2], [3]]

就会得到像上面的输出:

# Output: tensor([[1],
#           [5],
#           [7]])

Pour obtenir le même effet, nous utilisons d'abord

La fonction torch.unsqueeze ajoute une dimension au tenseur de [0, 2, 3] à [[0], [2], [3]],
Utilisez ensuite la fonction de répétition pour modifier l'index :
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([0, 2, 3])
indices = torch.unsqueeze(indices, dim=1)
indices = indices.repeat(1, input.size(1))
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1, 2],
#                  [5, 6],
#                  [7, 8]])

La fonction de répétition signifie répéter cet index. Le premier paramètre est 1, ce qui signifie que la ligne reste inchangée. Le deuxième paramètre est

input.size(1) signifie que la colonne répète input.size(1) fois, c'est-à-dire deux fois. Les indices obtenus sont :

tenseur([[0, 0],
        [2, 2],
        [3, 3]])

Lorsque nous rencontrons des tenseurs multidimensionnels, nous pouvons toujours opérer comme ça

-------------------------------------------------- ---------------------

d'ailleurs, on peut aussi opérer sur les dimensions, c'est-à-dire dim/axis,

Nous pouvons comparer

input = torch.tensor([[ 1,  2,  3],
                      [ 4,  5,  6],
                      [ 7,  8,  9],
                      [10, 11, 12]])

x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=x)


#tensor([[4, 5, 6],
#       [1, 2, 3]])


#tensor([[2],
#        [4]])
input = torch.tensor([[ 1,  2,  3],
                      [ 4,  5,  6],
                      [ 7,  8,  9],
                      [10, 11, 12]])

x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=index)


#tensor([[4, 5, 6],
#        [1, 2, 3]])

#tensor([[2, 2, 2],
#        [4, 4, 4]])

au fait :

        tf.gather_nd peut remplacer tf.gather, mais utiliser rassembler_nd introduira plus de paramètres supplémentaires. Pour le tenseur 4-D, en supposant que nous voulons remplacer tf.gather par tf.gather_nd, nous devons extraire les éléments de l'axe correspondant. À ce stade , Indices doit simplement former une matrice correspondant à l'indice des éléments souhaités.

Pensez ensuite à un parmas avec une forme de [2,3,4,5] et
comment utiliser tf.gather(params, axis = 3, indices=[0,2]) pour afficher le même résultat avec tf.gather_nd .

import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_uniform(shape=(2, 3, 4, 5), name="v"))
nnzs = [0,2]
nnzs = np.asarray(nnzs,"int32")
#initi= np.asarray(initi,"int32")
initi =np.zeros((2,3,4,nnzs.size,4),dtype=np.int)
print(initi.shape)
for i in range(initi.shape[0]):
    for j in range(initi.shape[1]):
        for k in range (initi.shape[2]):
           for l in range(nnzs.size):
                initi[i][j][k][l] =[i,j,k,nnzs[l]]
indices = tf.Variable(initial_value = initi, name="indices")
c = tf.gather_nd(a,initi)
d = tf.gather(a,indices= nnzs,axis =3)
print(c.get_shape())
print(d.get_shape())
e = tf.equal(c,d)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([e]))


Enfin, un tableau de valeurs vraies est généré.

Guess you like

Origin blog.csdn.net/djdjdhch/article/details/130598827