1. The scatter() function in PyTorch
scatter() and scatter_() do the same thing. The only difference is that scatter() returns a new tensor and leaves the original one unchanged, while scatter_() modifies the original tensor in place.
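A minimal sketch of this difference, assuming a recent PyTorch version. The 2×3 tensor and the index values are arbitrary choices for illustration:

```python
import torch

x = torch.zeros(2, 3)
index = torch.tensor([[0], [2]])  # row 0 writes column 0, row 1 writes column 2

# Out-of-place: returns a new tensor, x itself is not modified
y = x.scatter(1, index, 1.0)
unchanged = torch.equal(x, torch.zeros(2, 3))
print(unchanged)  # True
print(y)          # tensor([[1., 0., 0.], [0., 0., 1.]])

# In-place: writes directly into x
x.scatter_(1, index, 1.0)
print(x)          # tensor([[1., 0., 0.], [0., 0., 1.]])
```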
scatter(dim, index, src) takes 3 parameters:
- dim: the dimension along which to index
- index: the indices of the positions to write to
- src: the source value(s) to scatter, either a scalar or a tensor
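The examples below only scatter a scalar. When src is a tensor, each element src[i][j] is written to the position selected by index[i][j]; for dim=0 the rule is out[index[i][j]][j] = src[i][j]. A small sketch, where the src values are made up for illustration:

```python
import torch

src = torch.arange(1., 11.).reshape(2, 5)  # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])

# dim=0: out[index[i][j]][j] = src[i][j]
out = torch.zeros(3, 5).scatter(0, index, src)
print(out)
# tensor([[ 1.,  7.,  8.,  4.,  5.],
#         [ 0.,  2.,  0.,  9.,  0.],
#         [ 6.,  0.,  3.,  0., 10.]])
```

Every target position here is hit exactly once, so the result does not depend on write order.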
>>> import torch
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 2)
tensor([[2., 2., 2., 2., 2.],
        [0., 2., 0., 2., 0.],
        [2., 0., 2., 0., 2.]])
In the example above, the first parameter, 0, selects dimension 0, i.e. the rows. The second parameter, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), is the index tensor: it specifies which positions will be written. Because dim=0, each entry of index gives a row number, while the column is the entry's own column. For example, the first row of index, [0, 1, 2, 0, 0], writes to row 0 in column 0, row 1 in column 1, row 2 in column 2, and row 0 in columns 3 and 4. The second row, [2, 0, 0, 1, 2], likewise writes to rows 2, 0, 0, 1 and 2 of columns 0 to 4. The third parameter, 2, is the value written to every selected position; all other positions keep their original value of 0.
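To check the explanation above, here is a plain-Python loop that applies the same dim=0 rule. The helper scatter_dim0 is hypothetical, written only for illustration, and its result is compared with PyTorch's own:

```python
import torch

def scatter_dim0(rows, cols, index, value):
    """Hypothetical plain-loop version of torch.zeros(rows, cols).scatter_(0, index, value).

    For every entry index[i][j], the row to write is index[i][j]
    and the column is j, the entry's own column.
    """
    out = [[0.0] * cols for _ in range(rows)]
    for row in index:
        for j, target_row in enumerate(row):
            out[target_row][j] = value
    return out

index = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
manual = scatter_dim0(3, 5, index, 2.0)
reference = torch.zeros(3, 5).scatter_(0, torch.tensor(index), 2).tolist()
print(manual == reference)  # True
```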
Reference: https://www.cnblogs.com/dogecheng/p/11938009.html
Extending to a 3D tensor
>>> torch.zeros(2, 3, 5).scatter_(1, torch.tensor([[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]]), 2)
tensor([[[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]],

        [[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]]])
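Compare this with the 2D example: the first parameter, dim, determines which dimension the index values refer to. Here dim=1 selects the middle dimension (the rows of each 3×5 slice), so each of the two slices is filled exactly as in the 2D example. This pattern is commonly used to convert a semantic segmentation mask into a one-hot label.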
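A minimal sketch of the one-hot use case, assuming a mask of shape (N, H, W) that holds integer class ids. The mask values and num_classes are made up for illustration. The mask needs an extra class axis (unsqueeze(1)) because index must have the same number of dimensions as the output tensor:

```python
import torch
import torch.nn.functional as F

num_classes = 3
# Illustrative mask: 1 image of 2x3 pixels, class ids in [0, num_classes)
mask = torch.tensor([[[0, 1, 2],
                      [2, 1, 0]]])  # shape (N, H, W) = (1, 2, 3), dtype int64

n, h, w = mask.shape
# dim=1 is the class axis: one_hot[b][mask[b][y][x]][y][x] = 1
one_hot = torch.zeros(n, num_classes, h, w).scatter_(1, mask.unsqueeze(1), 1.0)
print(one_hot.shape)  # torch.Size([1, 3, 2, 3])

# Cross-check: F.one_hot puts the class axis last, so move it to dim 1
expected = F.one_hot(mask, num_classes).permute(0, 3, 1, 2).float()
print(torch.equal(one_hot, expected))  # True
```

torch.nn.functional.one_hot gives the same result directly; the scatter_ version simply makes the channel-first (N, C, H, W) layout explicit.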