公式ドキュメントを参照してください: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
定義: 元のテンソルから指定された dim と指定されたインデックスのデータを取得します
目的: バッチ テンソルから指定したインデックスのデータを取得するのに便利です。インデックスは高度にカスタマイズされており、順序どおりに実行することもできます。
インデックスは行ベクトルであり、インデックス dim = 0 を置き換えます。
dim=0、行をインデックスで置き換えます
インデックスは行ベクトルであり、インデックス dim = 1 を置き換えます。
dim=1、列をインデックスで置き換えます
初期インデックス dim=1
(0,0) 2 (0,2)
(0,1) 1 (0,1)
(0,2) 0 (0,0)
なぜ (0, 0) (0, 1) (0, 2) があるのか
インデックス = [[2, 1, 0]] を見てみましょう
は 1×3、つまり要素の添字です
インデックスが列ベクトルの場合、インデックスを dim = 0 および dim = 1 に置き換えます。
2次元行列インデックスの場合、インデックス(dim = 1)を置き換えます。
計算します:
結論は:
- 入力インデックスの形状は出力値の形状と同じです
- 入力インデックスのインデックス値は、インデックス内の対応する dim のインデックス値のみを置き換えます。
- 最終出力は、インデックスを置換した後の元のテンソルの値です。