torch.gather()函数

参见官方文档:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

定义:从原tensor中获取指定dim和指定index的数据

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的做

index做行向量,替换索引dim = 0

dim=0,用index替换行

index做行向量,替换索引dim = 1

 dim=1,用index替换列

   初始            index          dim=1

(0,0)         2               (0,2)

(0,1)         1               (0,1)

(0,2)         0               (0,0)

为什么会有(0,0) (0,1)(0,2)

且看 index =  [[2, 1, 0]]

为1×3,即使其对于元素的下标

如果index作为列向量,替换索引 dim = 0  以及 dim = 1

对于二维矩阵index,并替换索引(dim = 1)

计算:

 

结论:

  • 输入index的shape等于输出value的shape
  • 输入index的索引值仅替换该index中对应dim的index值
  • 最终输出为替换index后在原tensor中的值

猜你喜欢

转载自blog.csdn.net/weixin_43537097/article/details/132457209