参见官方文档:https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
定义:从原tensor中获取指定dim和指定index的数据
用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的做
index做行向量,替换索引dim = 0
dim=0,用index替换行
index做行向量,替换索引dim = 1
dim=1,用index替换列
初始 index dim=1
(0,0) 2 (0,2)
(0,1) 1 (0,1)
(0,2) 0 (0,0)
为什么会有(0,0) (0,1)(0,2)
且看 index = [[2, 1, 0]]
为1×3,即使其对于元素的下标
如果index作为列向量,替换索引 dim = 0 以及 dim = 1
对于二维矩阵index,并替换索引(dim = 1)
计算:
结论:
- 输入index的shape等于输出value的shape
- 输入index的索引值仅替换该index中对应dim的index值
- 最终输出为替换index后在原tensor中的值