Usage of scatter_ in PyTorch

First of all, this article only covers how to use the scatter method to build a one-hot vector/matrix.

scatter() and scatter_() are used in exactly the same way. The only difference is that scatter() returns a new tensor, while scatter_() modifies the original tensor in place.
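A minimal sketch of that difference, using a small 2x3 zero tensor and an index chosen only for illustration:

```python
import torch

x = torch.zeros(2, 3)
index = torch.tensor([[2], [0]])  # one target column per row; int64 by default

# Out-of-place: returns a new tensor and leaves x untouched.
y = x.scatter(1, index, 1.0)
print(x.sum().item(), y.sum().item())  # 0.0 2.0

# In-place: writes into x itself and returns x.
z = x.scatter_(1, index, 1.0)
print(x.sum().item(), z is x)  # 2.0 True
```

Because scatter_() returns the tensor it modified, assigning its result to a new name does not create a copy.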

The signature is scatter(dim, index, src), with the following parameters:

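Specifically: dim is the dimension along which index is applied. For a 2-D tensor, dim=0 means each value in index selects a row (the column is the element's own column in index), and dim=1 means each value in index selects a column (the row is the element's own row in index).

                  index is the tensor of positions to write to. It must have the same number of dimensions as the tensor being modified.

                  src is the source: either a tensor of values or a single scalar to fill in.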
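The row/column behaviour of dim can be checked with a small sketch; the 3x3 shape and the labels 2, 0, 1 are arbitrary choices for illustration:

```python
import torch

# dim=0: the VALUE in index picks the row, the POSITION in index keeps the column.
# a[index[0][j]][j] = 1  ->  one-hot encoding down each column
a = torch.zeros(3, 3).scatter_(0, torch.tensor([[2, 0, 1]]), 1.0)

# dim=1: the VALUE in index picks the column, the POSITION in index keeps the row.
# b[i][index[i][0]] = 1  ->  one-hot encoding along each row
b = torch.zeros(3, 3).scatter_(1, torch.tensor([[2], [0], [1]]), 1.0)

print(a)
print(b)
print(torch.equal(a.t(), b))  # True: same labels, transposed layout
```

For one-hot labels stored one sample per row, dim=1 is the usual choice.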

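x.scatter_(dim, index, src) works as follows: for every element of index, its position in index gives the coordinates in x along every dimension except dim, and its value gives the coordinate along dim. That position of x is overwritten with src (or with the corresponding element of src, if src is a tensor); for a zero tensor, the 0 there becomes src. For a 2-D x with dim=1 this means x[i][index[i][j]] = src[i][j]. Note that index must have dtype int64 (torch.long), otherwise an error is raised.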
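To make the rule concrete, here is a plain-loop sketch of scatter along dim=1 for a 2-D tensor. The helper name scatter_dim1_reference is hypothetical and the sketch does not check the int64 requirement; it only mirrors the formula x[i][index[i][j]] = src[i][j] and compares the result with PyTorch's own scatter():

```python
import torch

def scatter_dim1_reference(x, index, src):
    """Hypothetical helper: loop equivalent of x.scatter(1, index, src) for a 2-D x.

    Returns a new tensor, like scatter(); x itself is not modified.
    """
    out = x.clone()
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])  # value of index = target column along dim 1
            # A scalar src fills a constant; a tensor src supplies src[i][j].
            out[i, k] = src[i, j] if torch.is_tensor(src) else src
    return out

x = torch.zeros(3, 4)
index = torch.tensor([[0, 3], [1, 2], [2, 0]])  # int64, no repeated column within a row
src = torch.arange(1.0, 7.0).reshape(3, 2)

ref = scatter_dim1_reference(x, index, src)
print(torch.equal(ref, x.scatter(1, index, src)))  # True
print(torch.equal(scatter_dim1_reference(x, index, 1.0), x.scatter(1, index, 1.0)))  # True
```

The index deliberately avoids repeating a column within a row: if the same position is targeted twice with different tensor values, which value ends up there is not guaranteed.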

Example:

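import torch

x = torch.zeros(4, 8)                       # 4 samples, 8 classes
label = torch.tensor([[1], [5], [7], [6]])  # class of each sample, shape (4, 1), dtype int64
one_hot = x.scatter_(1, label, 1)           # in place: x[i][label[i][0]] = 1, so one_hot is x itself
print(one_hot)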
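The same idea can be wrapped in a small function for a 1-D tensor of labels. This is a sketch: the name one_hot_scatter is hypothetical, and the comparison with torch.nn.functional.one_hot is only there as a sanity check (that function returns an int64 tensor, hence the .float()):

```python
import torch
import torch.nn.functional as F

def one_hot_scatter(labels, num_classes):
    """Hypothetical helper: one-hot encode a 1-D label tensor with scatter_."""
    # scatter_ along dim=1 needs a 2-D int64 index of shape (N, 1)
    index = labels.long().unsqueeze(1)
    out = torch.zeros(labels.size(0), num_classes)
    return out.scatter_(1, index, 1.0)

labels = torch.tensor([1, 5, 7, 6])
one_hot = one_hot_scatter(labels, 8)
print(one_hot)
print(torch.equal(one_hot, F.one_hot(labels, num_classes=8).float()))  # True
```

The unsqueeze(1) step is what turns [1, 5, 7, 6] into the [[1], [5], [7], [6]] shape used in the example above.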

 In this example, one_hot = x.scatter_(1, label, 1) is applied row by row. Because dim is 1, each value in label is read as a column index. The first element of label is 1, so in the first row (row 0) the entry at column 1 is set to 1; the value written is 1 because the third argument of scatter_ is 1. Likewise, row 1 gets a 1 at column 5, row 2 at column 7, and row 3 at column 6. Since scatter_ works in place, one_hot and x are the same tensor.
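The number 1 plays two different roles in that call (dim=1 and fill value 1). A quick sketch that swaps the fill value for 9.0, chosen only so the written value cannot be confused with a column index, shows which positions each row receives:

```python
import torch

label = torch.tensor([[1], [5], [7], [6]])

# Fill with 9.0 instead of 1 so the written VALUE is distinct from the column INDEX.
out = torch.zeros(4, 8).scatter_(1, label, 9.0)

# Row i receives the fill value at column label[i][0].
print(out.nonzero().tolist())  # [[0, 1], [1, 5], [2, 7], [3, 6]]
print(out[out != 0].tolist())  # [9.0, 9.0, 9.0, 9.0]
```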

This example comes from: scatter_() function in pytorch - daremosiranaihana - 博客园 (cnblogs)


Original post: blog.csdn.net/qq_43438974/article/details/127076760