Understanding and usage of gather in PyTorch

Preface

I was recently reading some source code I found online, came across gather, and didn't understand it. I searched on Baidu and found an explanation by a Zhihu blogger (https://zhuanlan.zhihu.com/p/352877584). After reading it, I worked out my own understanding, which I describe below.
Some figures are taken from that Zhihu post. If this causes any offense, please send me a private message!

An example to make this concrete

import torch

# Prepare a 3x3 tensor
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

The contents of tensor_0 are:

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

tensor_0.gather(0, index)

With dim=0, each value in index replaces the row coordinate

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

Output:

tensor([[9, 7, 5]])
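gather can be called either as a tensor method, as above, or as the function torch.gather(input, dim, index). A minimal check, recreating tensor_0 and index from above, that both forms give the same result and that the output takes the shape of index (out_method and out_func are illustrative names):

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])

# Method form, as used in the example
out_method = tensor_0.gather(0, index)
# Equivalent functional form: torch.gather(input, dim, index)
out_func = torch.gather(tensor_0, 0, index)

print(out_func)                           # tensor([[9, 7, 5]])
print(torch.equal(out_method, out_func))  # True
print(out_func.shape == index.shape)      # True: output has the same shape as index
```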

Explanation:

Look at index, which has the same shape as the output tensor_1, so each of its positions maps to the same position in tensor_1:

[[2, 1, 0]]

The value 2 sits at position (0,0). Because dim=0, gather(0, index) replaces the row coordinate of (0,0), which is 0, with 2, giving the coordinate (2,0) in tensor_0:

[[ 3,  4,  5],
 [ 6,  7,  8],
 [ 9, 10, 11]]

That position holds 9. Similarly, the value 1 sits at position (0,1); replacing its row coordinate 0 with 1 gives (1,1), which holds 7 in tensor_0. Finally, the value 0 sits at position (0,2); replacing its row coordinate 0 with 0 gives (0,2), which holds 5 in tensor_0.
Therefore the output is [[9, 7, 5]]. In general, for dim=0 each output element is tensor_1[i][j] = tensor_0[index[i][j]][j].
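The coordinate-replacement rule for dim=0 can be sketched as a plain double loop. gather_dim0 is a hypothetical helper written only for illustration (2-D tensors, dim=0 only); it is not how PyTorch implements gather internally, but it should produce the same values:

```python
import torch

def gather_dim0(src, index):
    """Illustrative loop equivalent of src.gather(0, index) for 2-D tensors."""
    out = torch.empty_like(index, dtype=src.dtype)
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            # dim=0: the index value replaces the ROW coordinate; column j is kept
            out[i, j] = src[int(index[i, j]), j]
    return out

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])
print(gather_dim0(tensor_0, index))  # tensor([[9, 7, 5]])
print(torch.equal(gather_dim0(tensor_0, index), tensor_0.gather(0, index)))  # True
```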

tensor_0.gather(1, index)

With dim=1, each value in index replaces the column coordinate

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

Output:

tensor([[5, 4, 3]])
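The output always takes the shape of index, and according to the PyTorch documentation index only needs index.size(d) <= input.size(d) for every dimension d other than dim, so it does not have to match tensor_0's shape. A minimal sketch with a hypothetical (3, 1) index that picks one column per row along dim=1 (tensor_2 is an illustrative name):

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
# Hypothetical index of shape (3, 1): one column position for each row
index = torch.tensor([[2], [0], [1]])
tensor_2 = tensor_0.gather(1, index)

print(tensor_2.shape)                 # torch.Size([3, 1])
print(tensor_2.squeeze(1).tolist())   # [5, 6, 10]
```

Row 0 reads column 2 (value 5), row 1 reads column 0 (value 6), and row 2 reads column 1 (value 10).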

Explanation:

Again, look at index, whose positions match those of the output tensor_1:

[[2, 1, 0]]

The value 2 sits at position (0,0). Because dim=1, gather(1, index) replaces the column coordinate of (0,0), which is 0, with 2, giving the coordinate (0,2) in tensor_0:

[[ 3,  4,  5],
 [ 6,  7,  8],
 [ 9, 10, 11]]

That position holds 5. Similarly, the value 1 sits at position (0,1); replacing its column coordinate 1 with 1 gives (0,1), which holds 4 in tensor_0. Finally, the value 0 sits at position (0,2); replacing its column coordinate 2 with 0 gives (0,0), which holds 3 in tensor_0.
Therefore the output is [[5, 4, 3]]. In general, for dim=1 each output element is tensor_1[i][j] = tensor_0[i][index[i][j]].
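For dim=1 the same rule can also be written with advanced indexing: keep each row coordinate and take the column coordinate from index. A minimal sketch, where rows and via_indexing are illustrative names and the row-number tensor is built so that it broadcasts against index:

```python
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])

# Row coordinates are kept: output row i reads from row i of tensor_0.
# Shape (number of index rows, 1) so it broadcasts against index.
rows = torch.arange(index.size(0)).unsqueeze(1)
# Advanced indexing: via_indexing[i][j] = tensor_0[rows[i][0]][index[i][j]]
via_indexing = tensor_0[rows, index]

print(via_indexing)                                          # tensor([[5, 4, 3]])
print(torch.equal(via_indexing, tensor_0.gather(1, index)))  # True
```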

Source: blog.csdn.net/weixin_46088099/article/details/125473771