[Selected] Analysis of the scatter_add_ function in PyTorch

Foreword:

In PyTorch, scatter_add_ performs scatter addition by index: it adds each value of a source tensor to a target tensor at the position given by the corresponding entry of an index tensor. The trailing underscore (_) indicates that it is an in-place operation, so it modifies the original tensor instead of creating a new one.

The function signature is as follows:

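input.scatter_add_(dim, index, src)

scatter_add_ is a method of the target tensor; the out-of-place counterpart is torch.scatter_add(input, dim, index, src).

  • input: The target tensor of the scatter operation, that is, the tensor the method is called on.
  • dim: Specifies the dimension along which the scatter operation is performed.
  • index: An integer (int64) tensor containing the positions in input to which the values are scattered.
  • source: The source tensor containing the values to be scattered (the keyword name in PyTorch is src).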
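To make the in-place behaviour concrete, here is a minimal sketch comparing the method with its out-of-place counterpart. The tensor names and values are made up for illustration; only torch.scatter_add and Tensor.scatter_add_ come from PyTorch itself.

```python
import torch

# Illustrative values, not taken from the article
target = torch.zeros(3)
index = torch.tensor([0, 2, 0])
src = torch.tensor([1.0, 2.0, 3.0])

# Out-of-place: returns a new tensor and leaves `target` untouched
out = torch.scatter_add(target, 0, index, src)
target_unchanged = torch.equal(target, torch.zeros(3))

# In-place: writes into `target` and returns a tensor sharing its memory
returned = target.scatter_add_(0, index, src)
same_memory = returned.data_ptr() == target.data_ptr()

print(out.tolist())     # [4.0, 0.0, 2.0]
print(target.tolist())  # [4.0, 0.0, 2.0]
print(target_unchanged, same_memory)  # True True
```

Position 0 receives 1.0 + 3.0 because two index entries point to it, which is the "add" part of scatter addition.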

Let’s break down the parameters of this function in detail:

  • input: This is the target tensor of the scatter operation. It is modified in place, so its values change after the call.
  • dim: Indicates along which dimension the scatter operation is performed. For example, if dim=0, the values in index select positions along the first dimension. Every value in index must lie between 0 and input.size(dim) - 1.
  • index: This is an integer tensor that specifies where in input each source value is added. It must have the same number of dimensions as input and source. In every dimension d, index.size(d) must not exceed source.size(d), and in every dimension other than dim it must not exceed input.size(d).
  • source: This is the tensor containing the values to be scattered. It must have the same number of dimensions as index and be at least as large as index in every dimension; only the elements covered by index are used.

The working principle is as follows: for the given dim, each element of source is added to input at the position obtained by replacing its coordinate along dim with the corresponding value in index. For a 2-D tensor with dim=0 this means input[index[i][j]][j] += source[i][j], and with dim=1 it means input[i][index[i][j]] += source[i][j]. When several index entries point to the same position, their values are summed.
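The working principle described above can be written out as plain loops. The sketch below is only an illustration for the 2-D case: scatter_add_reference is a hypothetical helper, not a PyTorch function, and the tensors are made-up values. It checks that the loop version agrees with scatter_add_ for both dim=0 and dim=1.

```python
import torch

def scatter_add_reference(target, dim, index, src):
    """Hypothetical loop-based equivalent of scatter_add_ for 2-D tensors."""
    result = target.clone()
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            if dim == 0:
                result[k, j] += src[i, j]   # row comes from the index
            else:
                result[i, k] += src[i, j]   # column comes from the index
    return result

# Illustrative values: every index entry is valid for both dim=0 and dim=1
target = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0], [2, 0, 1, 1]])
src = torch.arange(1.0, 9.0).reshape(2, 4)

matches = []
for dim in (0, 1):
    expected = scatter_add_reference(target, dim, index, src)
    actual = target.clone().scatter_add_(dim, index, src)
    matches.append(torch.equal(expected, actual))

print(matches)  # [True, True]
```

The loop form is far slower than the built-in operation and is meant only to make the indexing rule easy to read.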

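This function is useful whenever values must be accumulated by index, for example when working with sparse data stored as (index, value) pairs, or during backpropagation in neural networks, where the gradient of an indexing operation such as gather is a scatter addition of the incoming gradient.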
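As a small illustration of the backpropagation use case, the following sketch checks numerically that the gradient autograd computes for gather equals a scatter addition of the upstream gradient. The values are made up, and the sketch says nothing about how PyTorch implements this internally; it only compares the results.

```python
import torch

# Illustrative values, not taken from the article
x = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)
index = torch.tensor([0, 0, 2, 0])

# Forward: y = [x[0], x[0], x[2], x[0]]
y = x.gather(0, index)

# Backward with an explicit upstream gradient
upstream = torch.tensor([1.0, 2.0, 3.0, 4.0])
y.backward(upstream)

# The same gradient built by hand with scatter addition
manual_grad = torch.zeros(3).scatter_add_(0, index, upstream)

print(x.grad.tolist())       # [7.0, 0.0, 3.0]
print(manual_grad.tolist())  # [7.0, 0.0, 3.0]
```

x[0] was read three times in the forward pass, so it collects the three matching upstream values (1 + 2 + 4), while x[1] was never read and receives zero.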

Typical uses of scatter_add_ accumulate sparsely located values into a dense tensor. Here is a simple example showing how to use it:

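import torch

# Create a zero-filled tensor as the target
size = (4, 4)
input_tensor = torch.zeros(size)

# Create the index tensor and the source tensor.
# The source must have the same number of dimensions as the index,
# so both are 2 x 4 here.
index = torch.tensor([[0, 1, 2, 0], [2, 0, 1, 3]])
source = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])

# Apply scatter_add_ along the chosen dimension:
# input_tensor[index[i][j]][j] += source[i][j]
dim = 0
input_tensor.scatter_add_(dim, index, source)

# Print the result
print(input_tensor)
# tensor([[1., 6., 0., 4.],
#         [0., 2., 7., 0.],
#         [5., 0., 3., 0.],
#         [0., 0., 0., 8.]])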

In this example, we create a 4x4 zero-filled tensor input_tensor, and then define an index tensor index and a source tensor source. Next, we use scatter_add_ to perform a scatter operation along dimension dim=0, adding the values in source to input_tensor at the rows given by index. Finally, we print the modified input_tensor.
Each position of the resulting input_tensor holds the sum of all source values whose index entry points to it, and positions that no index entry refers to remain zero. This is a simple example, and more complex operations and data structures may be involved in practice.
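As one slightly more practical case, scatter_add_ can compute per-group sums and counts in a single call each. The sketch below is an assumption about how one might do this; values, group_ids and num_groups are hypothetical names and made-up data, not part of the original example.

```python
import torch

# Hypothetical data: each value belongs to the group with the matching id
values = torch.tensor([0.5, 1.5, 2.0, 4.0, 1.0])
group_ids = torch.tensor([0, 2, 0, 1, 2])
num_groups = 3

# Sum of the values in each group
group_sums = torch.zeros(num_groups).scatter_add_(0, group_ids, values)

# Number of members in each group: scatter a 1 for every value
group_counts = torch.zeros(num_groups).scatter_add_(
    0, group_ids, torch.ones_like(values)
)

print(group_sums.tolist())    # [2.5, 4.0, 2.5]
print(group_counts.tolist())  # [2.0, 1.0, 2.0]
```

Dividing group_sums by group_counts would then give per-group means, provided no group is empty.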

Origin blog.csdn.net/weixin_42628609/article/details/134449281