Organisé à partir de documents officiels : https://www.oschina.net/action/GoToLink?url=https%3A%2F%2Fpytorch.org%2Fdocs%2Fstable%2Fgenerated%2Ftorch.gather.html%3Fhighlight%3Dgather%23torch.gather
arrière-plan 0x01
En faisant de l'apprentissage par renforcement DDQN réel (CtrlC) test (CtrlV), j'ai rencontré certaines fonctions que je ne comprends pas très bien, voici quelques interprétations.
Fonction de collecte 0x02
Rassembler, littéralement traduit par agrégation, rassemblement.
Exécutons deux exemples pour démontrer ce que cela va produire :
Créez d'abord deux tenseurs: a et b
a = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
tensor([[1, 2, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 1, 0]])
Ce qui suit montre l'effet :
>>> a
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])
>>> b
tensor([[1, 2, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 1, 0]])
>>> a.gather(0, b)
tensor([[ 6, 12, 3, 4, 5],
[ 1, 7, 3, 4, 5],
[ 1, 2, 3, 9, 5]])
>>> a.gather(1, b)
tensor([[ 2, 3, 1, 1, 1],
[ 6, 7, 6, 6, 6],
[11, 11, 11, 12, 11]])
La fonction de collecte équivaut en fait à sélectionner et à remplacer le tenseur d'origine.
Le premier paramètre est dim, indiquant sur quelle dimension on veut faire la sélection (par exemple, s'il faut faire une sélection sur les lignes ou les colonnes de la matrice).
Le deuxième paramètre est l'indice, qui ne doit pas nécessairement être le même que la forme de l'original a.
Ci-dessous, nous utilisons directement un schéma de réseau pour illustrer le principe de fonctionnement de la fonction de collecte.
(J'ai tapé beaucoup de mots, mais je ne peux toujours pas dire le QAQ