Dans Tensorflow et PyTorch,
gather
La fonction est utilisée pour agréger les éléments du tenseur d'entrée en fonction de l'index spécifié et renvoyer un nouveau tenseur.
Plus précisément, la fonction et l'utilisation de cette fonction sont les suivantes :
Fonctions dans TensorFlow tf.gather()
:
tf.gather(params, indices, axis=None, batch_dims=0, name=None)
Parmi eux, params
représente le tenseur d'entrée ; indices
représente l'index qui doit être agrégé, qui peut être une liste ou un tenseur constant ; axis
est un entier spécifiant la dimension d'agrégation, s'il n'est pas transmis, la valeur par défaut est 0 ; est un batch_dims
entier spécifiant le nombre de les dimensions du lot, généralement utilisées lors de l'agrégation de plusieurs échantillons ; name
sont un nom d'opération facultatif.
Par exemple, pour les tenseurs params = [[1, 2], [3, 4], [5, 6], [7, 8]]
qui nécessitent [0, 2, 3]
une agrégation sur la première dimension par index, vous pouvez utiliser le code suivant :
import tensorflow as tf
params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = [0, 2, 3]
output = tf.gather(params, indices, axis=0)
print(output)
# Output: [[1 2]
# [5 6]
# [7 8]]
Fonctions dans PyTorch torch.gather()
:
torch.gather(input, dim, index, out=None)
Parmi eux, input
représente le tenseur d'entrée ; dim
représente la dimension qui doit être agrégée ; index
représente l'index qui doit être agrégé, qui peut être une liste ou un tenseur constant ; out
et est un tenseur de sortie facultatif.
Par exemple, pour les tenseurs input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
qui nécessitent torch.tensor([[0], [2], [3]])
une agrégation sur la première dimension par index, vous pouvez utiliser le code suivant :
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([[0], [2], [3]])
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1],
# [5],
# [7]])
——————————————————————————————
différent:
L'index de tf.gather est destiné à [0, 2, 3]时,可以
l'agrégation dans la première dimension, mais
如果torch.gather的索引也设置为[0, 2, 3],则会报错:
RuntimeError: Index tensor must have the same number of dimensions as input tensor
而若torch.gather的索引设置为[[0], [2], [3]]
就会得到像上面的输出:
# Output: tensor([[1],
# [5],
# [7]])
Pour obtenir le même effet, nous utilisons d'abord
La fonction torch.unsqueeze ajoute une dimension au tenseur de [0, 2, 3] à [[0], [2], [3]], Utilisez ensuite la fonction de répétition pour modifier l'index :
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([0, 2, 3])
indices = torch.unsqueeze(indices, dim=1)
indices = indices.repeat(1, input.size(1))
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1, 2],
# [5, 6],
# [7, 8]])
La fonction de répétition signifie répéter cet index. Le premier paramètre est 1, ce qui signifie que la ligne reste inchangée. Le deuxième paramètre est
input.size(1) signifie que la colonne répète input.size(1) fois, c'est-à-dire deux fois. Les indices obtenus sont : tenseur([[0, 0], [2, 2], [3, 3]])
Lorsque nous rencontrons des tenseurs multidimensionnels, nous pouvons toujours opérer comme ça
-------------------------------------------------- ---------------------
d'ailleurs, on peut aussi opérer sur les dimensions, c'est-à-dire dim/axis,
Nous pouvons comparer
input = torch.tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=x)
#tensor([[4, 5, 6],
# [1, 2, 3]])
#tensor([[2],
# [4]])
input = torch.tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=index)
#tensor([[4, 5, 6],
# [1, 2, 3]])
#tensor([[2, 2, 2],
# [4, 4, 4]])
au fait :
tf.gather_nd peut remplacer tf.gather, mais utiliser rassembler_nd introduira plus de paramètres supplémentaires. Pour le tenseur 4-D, en supposant que nous voulons remplacer tf.gather par tf.gather_nd, nous devons extraire les éléments de l'axe correspondant. À ce stade , Indices doit simplement former une matrice correspondant à l'indice des éléments souhaités.
Pensez ensuite à un parmas avec une forme de [2,3,4,5] et
comment utiliser tf.gather(params, axis = 3, indices=[0,2]) pour afficher le même résultat avec tf.gather_nd .
import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_uniform(shape=(2, 3, 4, 5), name="v"))
nnzs = [0,2]
nnzs = np.asarray(nnzs,"int32")
#initi= np.asarray(initi,"int32")
initi =np.zeros((2,3,4,nnzs.size,4),dtype=np.int)
print(initi.shape)
for i in range(initi.shape[0]):
for j in range(initi.shape[1]):
for k in range (initi.shape[2]):
for l in range(nnzs.size):
initi[i][j][k][l] =[i,j,k,nnzs[l]]
indices = tf.Variable(initial_value = initi, name="indices")
c = tf.gather_nd(a,initi)
d = tf.gather(a,indices= nnzs,axis =3)
print(c.get_shape())
print(d.get_shape())
e = tf.equal(c,d)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([e]))
Enfin, un tableau de valeurs vraies est généré.