PyTorch marchant sur la fosse : fonction de collecte

Organisé à partir de documents officiels : https://www.oschina.net/action/GoToLink?url=https%3A%2F%2Fpytorch.org%2Fdocs%2Fstable%2Fgenerated%2Ftorch.gather.html%3Fhighlight%3Dgather%23torch.gather

arrière-plan 0x01

En faisant de l'apprentissage par renforcement DDQN réel (CtrlC) test (CtrlV), j'ai rencontré certaines fonctions que je ne comprends pas très bien, voici quelques interprétations.

Fonction de collecte 0x02

Rassembler, littéralement traduit par agrégation, rassemblement.

Exécutons deux exemples pour démontrer ce que cela va produire :

Créez d'abord deux tenseurs: a et b

a = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
tensor([[1, 2, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 1, 0]])

Ce qui suit montre l'effet :

>>> a
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]])
>>> b
tensor([[1, 2, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0]])
>>> a.gather(0, b)
tensor([[ 6, 12,  3,  4,  5],
        [ 1,  7,  3,  4,  5],
        [ 1,  2,  3,  9,  5]])
>>> a.gather(1, b)
tensor([[ 2,  3,  1,  1,  1],
        [ 6,  7,  6,  6,  6],
        [11, 11, 11, 12, 11]])

La fonction de collecte équivaut en fait à sélectionner et à remplacer le tenseur d'origine.

Le premier paramètre est dim, indiquant sur quelle dimension on veut faire la sélection (par exemple, s'il faut faire une sélection sur les lignes ou les colonnes de la matrice).

Le deuxième paramètre est l'indice, qui ne doit pas nécessairement être le même que la forme de l'original a.

Ci-dessous, nous utilisons directement un schéma de réseau pour illustrer le principe de fonctionnement de la fonction de collecte.

(J'ai tapé beaucoup de mots, mais je ne peux toujours pas dire le QAQ
insérez la description de l'image ici

Acho que você gosta

Origin blog.csdn.net/weixin_43466027/article/details/117385716
Recomendado
Clasificación