torch.gather() 関数

公式ドキュメントを参照してください: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

定義: 元のテンソルから指定された dim と指定されたインデックスのデータを取得します

目的: バッチ テンソルから指定したインデックスのデータを取得するのに便利です。インデックスは高度にカスタマイズされており、順序どおりに実行することもできます。

インデックスは行ベクトルであり、インデックス dim = 0 を置き換えます。

dim=0、行をインデックスで置き換えます

インデックスは行ベクトルであり、インデックス dim = 1 を置き換えます。

 

 

 dim=1、列をインデックスで置き換えます

   初期インデックス dim=1

(0,0) 2               (0,2)

(0,1) 1                (0,1)

(0,2) 0                (0,0)

なぜ (0, 0) (0, 1) (0, 2) があるのか

インデックス = [[2, 1, 0]] を見てみましょう

は 1×3、つまり要素の添字です

インデックスが列ベクトルの場合、インデックスを dim = 0 および dim = 1 に置き換えます。

 

2次元行列インデックスの場合、インデックス(dim = 1)を置き換えます。

計算します:

 

 

結論は:

  • 入力インデックスの形状は出力値の形状と同じです
  • 入力インデックスのインデックス値は、インデックス内の対応する dim のインデックス値のみを置き換えます。
  • 最終出力は、インデックスを置換した後の元のテンソルの値です。

おすすめ

転載: blog.csdn.net/weixin_43537097/article/details/132457209