First, here is the link to the official documentation:
https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
Below, I restate the official documentation in plain language.
gather, as the name implies, collects elements. It is a bit like lining up in military training: the team is arranged in the order the instructor wants.
A more fitting analogy: gather looks up elements by index and returns the results as a tensor.
1. Get a tensor:
import torch
a = torch.arange(15).view(3, 5)
print(a)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14]])
2. Build a lookup rule:
(Every element of tensor b is an index into tensor a.)
b = torch.zeros_like(a)
b[1][2] = 1
b[0][0] = 1
print(b)
# tensor([[1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 0]])
3. Look up along the dimension dim:
c = a.gather(0, b) # dim=0
d = a.gather(1, b) # dim=1
print(c)
# tensor([[5, 1, 2, 3, 4],
#         [0, 1, 7, 3, 4],
#         [0, 1, 2, 3, 4]])
print(d)
# tensor([[ 1,  0,  0,  0,  0],
#         [ 5,  5,  6,  5,  5],
#         [10, 10, 10, 10, 10]])
OK, the pattern may be a little hard to see at this point.
If dim=0, the values in b are indices along dimension 0 of a (row numbers).
If dim=1, the values in b are indices along dimension 1 of a (column numbers).
For example, b[0][0] is 1. When dim=0, that 1 is a row index, so the output at position [0][0] is a[1][0], which is 5.
When dim=1, that 1 is a column index, so the output at position [0][0] is a[0][1], which is 1.
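The rule above can be written out as explicit loops. This is only a sketch on plain Python lists for the 2-D case. `gather_2d` is a hypothetical helper introduced for illustration, not part of PyTorch, and torch.gather itself is not called here.

```python
def gather_2d(src, index, dim):
    """Loop-based sketch of the 2-D gather rule on nested lists."""
    rows, cols = len(index), len(index[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                # The index value picks the row; the column stays j.
                out[i][j] = src[index[i][j]][j]
            else:
                # The index value picks the column; the row stays i.
                out[i][j] = src[i][index[i][j]]
    return out


# Same values as tensors a and b above, written as nested lists.
a = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
b = [[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]]

c = gather_2d(a, b, dim=0)
d = gather_2d(a, b, dim=1)
print(c)  # [[5, 1, 2, 3, 4], [0, 1, 7, 3, 4], [0, 1, 2, 3, 4]]
print(d)  # [[1, 0, 0, 0, 0], [5, 5, 6, 5, 5], [10, 10, 10, 10, 10]]
```

The two printed results match c and d from step 3, which suggests the loops capture the same rule for this example.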
The final output can be seen as a query on a: every output element is an element of a, and the query indices are stored in b. The output has the same shape as b.
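To see that the output takes its shape from the index rather than from a, here is a small sketch with index tensors that are smaller than a. The names `idx0` and `idx1` and their values are made up for illustration.

```python
import torch

a = torch.arange(15).view(3, 5)

# A 2x3 index used along dim=0: each value is a row number of a.
idx0 = torch.tensor([[0, 2, 1],
                     [2, 2, 0]])
out0 = a.gather(0, idx0)
print(tuple(out0.shape))  # (2, 3), the same shape as idx0
print(out0.tolist())      # [[0, 11, 7], [10, 11, 2]]

# A 3x1 index used along dim=1: picks one column per row of a.
idx1 = torch.tensor([[4], [0], [2]])
out1 = a.gather(1, idx1)
print(out1.tolist())      # [[4], [5], [12]]
```

The second call is a common pattern: one index per row selects one element from each row.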
Here is a diagram found online to illustrate this, where index corresponds to b, src corresponds to a, and the values in the grid are all reduced by 1. The left picture corresponds to dim=0, and the right picture corresponds to dim=1.