pytorch中scatter()与scatter_()函数的用法与区别

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
其作用是根据index将src中的值写到self中, dim决定了维度
这里需要注意的一点是self的dtype要和src的dtype相同!!!例如:

torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

这里的self的dtype要和src的dtype相同。
函数的作用以3D的tensor举例子:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

举个具体的例子:

src = torch.arange(1, 11).reshape((2, 5))
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 2]])
# tensor([[0, 1, 2, 0],
#		  [1, 0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 7, 0, 4, 0],
#         [6, 2, 8, 0, 0],
#         [0, 0, 3, 9, 0]])

# 分析:index的i取值为0-1,j的取值从0-3都可以
# self[index[0][0]][0] = self[0][0] = src[0][0] = 1
# self[index[0][1]][1] = self[1][1] = src[0][1] = 2
# self[index[0][2]][2] = self[2][2] = src[0][2] = 3
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# self[index[1][0]][0] = self[1][0] = src[1][0] = 6
# self[index[1][1]][1] = self[0][1] = src[1][1] = 7
# self[index[1][2]][2] = self[1][2] = src[1][2] = 8
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9

这里还有个有意思的事情, 上面的情况是没有重叠的情况,假设index的上下两行中有重叠的元素,比如

index = torch.tensor([[0, 1, 2, 0],
					  [1, 0, 1, 0]])

注意第一行的最后一个元素与第二行的最后一个元素相同了, 都为0。(之前第二行最后一个元素为2)
这样的话上面的取值

# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[2][3] = src[1][3] = 9
变为了
# ...
# self[index[0][3]][3] = self[0][3] = src[0][3] = 4
# ...
# self[index[1][3]][3] = self[0][3] = src[1][3] = 9

可以看到self[0][3]有了2个赋值,一次是根据i=0,j=3所赋的4;另一次是根据i=1,j=3所赋的9;根据前后顺序关系,9会把4个给覆盖掉,因此最终得到的结果变为:

tensor([[1, 7, 0, 9, 0],
        [6, 2, 8, 0, 0],
        [0, 0, 3, 0, 0]])

scatter()与scatter_()的区别在于scatter_()是原地操作的。
举例,b = a.scatter(dim, index, src)后a的值不会发生变化
相对的, b = a.scatter_(dim, index, src)后a的值发生变化, 变得与b相等

猜你喜欢

转载自blog.csdn.net/qq_43666068/article/details/130860504