Pytorch之scatter(),gather()函数

scatter函数:
scatter_(dim, index, src) → Tensor
Parameters:

  • dim (int) – the axis along which to index index (LongTensor) – the
  • indices of elements to scatter, can be either empty or the same size
    of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified

(1)维度dim:整数,可以是0,1,2,3…

(2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置

(3)原数组input:也是一个tensor,其中的数据类型任意

gather函数:
torch.gather(input, dim, index, out=None) → Tensor
Parameters:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
  • out (Tensor, optional) – 目标张量

猜你喜欢

转载自blog.csdn.net/m0_46429066/article/details/105014616