Torch.Tensor.scatter_( ) 用法解读

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

将张量src中的值按参数index中指定的索引写入self。对于src中的值,当dimension!=dim时它的输出索引就是它在src中的索引。当dimension=dim时,它的输出索引就是它本身的值。

它是torch.gather()的反向操作。torch.gather()的具体用法见torch.gather() 用法解读_00000cj的博客-CSDN博客

对于一个三维的张量,self的更新方式如下

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

示例1

import torch

src = torch.arange(1, 11).reshape((2, 5))
print(src.shape)  # torch.Size([2, 5])
print(src)
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])

index = torch.tensor([[0, 1, 2, 0]])
print(index.shape)  # torch.Size([1, 4])
print(index)  # tensor([[0, 1, 2, 0]])

result = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
print(result)
# tensor([[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]])

其中dim=0,result中1,2,3,4四个值第0维的索引分别为0,1,2,0,这也就是index中对应的四个值。

这里张量本身shape=(3,5)即dimension=2,因此这里除了第0维,其它的维度只有一个第1维。这里result中1,2,3,4四个值第1维的索引分别为0,1,2,3,这也就是index中的0,1,2,0这四个值在index本身这个张量中第1维中对应的索引。

示例2

import torch

src = torch.arange(1, 11).reshape((2, 5))
print(src.shape)  # torch.Size([2, 5])
print(src)
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])

index = torch.tensor([[0, 1, 2], [0, 1, 4]])
print(index.shape)  # torch.Size([2, 3])
print(index)
# tensor([[0, 1, 2],
#         [0, 1, 4]])
result = torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
print(result)
# tensor([[1, 2, 3, 0, 0],
#         [6, 7, 0, 0, 8],
#         [0, 0, 0, 0, 0]])

这里src不变,index和dim变了。

index中第一个元素0的索引为[0][0],对应取出src[0][0]=1,因为dim=1,因此将第一维的索引即[0][0]中的第二个0换成index[0][0]的值,这里正好也是0。所以最终将self[0][0]的值换成src[0][0]即1。

同理,再看index中的最后一个元素的索引为[1][2],取出对应位置src[1][2]=8,因为dim=1,将[1][2]中的第一维换成index[1][2]的值即4,因此最终self[1][4]=src[1][2]=8。

扫描二维码关注公众号,回复: 14638709 查看本文章

参考

https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

猜你喜欢

转载自blog.csdn.net/ooooocj/article/details/129336897
今日推荐