Tensor.scatter_(dim, index, src, reduce=None) → Tensor
将张量src中的值按参数index中指定的索引写入self。对于src中的值,当dimension!=dim时它的输出索引就是它在src中的索引。当dimension=dim时,它的输出索引就是它本身的值。
它是torch.gather()的反向操作。torch.gather()的具体用法见torch.gather() 用法解读_00000cj的博客-CSDN博客
对于一个三维的张量,self的更新方式如下
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
示例1
import torch
src = torch.arange(1, 11).reshape((2, 5))
print(src.shape) # torch.Size([2, 5])
print(src)
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
print(index.shape) # torch.Size([1, 4])
print(index) # tensor([[0, 1, 2, 0]])
result = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
print(result)
# tensor([[1, 0, 0, 4, 0],
# [0, 2, 0, 0, 0],
# [0, 0, 3, 0, 0]])
其中dim=0,result中1,2,3,4四个值第0维的索引分别为0,1,2,0,这也就是index中对应的四个值。
这里张量本身shape=(3,5)即dimension=2,因此这里除了第0维,其它的维度只有一个第1维。这里result中1,2,3,4四个值第1维的索引分别为0,1,2,3,这也就是index中的0,1,2,0这四个值在index本身这个张量中第1维中对应的索引。
示例2
import torch
src = torch.arange(1, 11).reshape((2, 5))
print(src.shape) # torch.Size([2, 5])
print(src)
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
print(index.shape) # torch.Size([2, 3])
print(index)
# tensor([[0, 1, 2],
# [0, 1, 4]])
result = torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
print(result)
# tensor([[1, 2, 3, 0, 0],
# [6, 7, 0, 0, 8],
# [0, 0, 0, 0, 0]])
这里src不变,index和dim变了。
index中第一个元素0的索引为[0][0],对应取出src[0][0]=1,因为dim=1,因此将第一维的索引即[0][0]中的第二个0换成index[0][0]的值,这里正好也是0。所以最终将self[0][0]的值换成src[0][0]即1。
同理,再看index中的最后一个元素的索引为[1][2],取出对应位置src[1][2]=8,因为dim=1,将[1][2]中的第一维换成index[1][2]的值即4,因此最终self[1][4]=src[1][2]=8。
参考
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_