PyTorch pitfalls: the gather function

Compiled from the official documentation: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

0x01 Background

While running a DDQN reinforcement-learning ex(Ctrl+C)periment(Ctrl+V), I ran into a few functions I didn't understand very well. Here are some notes on them.
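For context, the place gather typically shows up in DQN/DDQN code is when selecting, from the Q-values of all actions, the one for the action actually taken. A minimal sketch, with made-up names and shapes (not from the original experiment):

import torch

# Q-values for a batch of 4 states with 2 possible actions each (random numbers)
q_values = torch.randn(4, 2)
# The action actually taken in each state, as a column of indices
actions = torch.randint(2, (4, 1))

# For each row i: chosen[i][0] == q_values[i][actions[i][0]]
chosen = q_values.gather(1, actions)
print(chosen.shape)  # torch.Size([4, 1])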

0x02 The gather function

The word gather literally means "to collect" or "to assemble".

Let's run two examples to see what it produces:

First, create two tensors, a and b:

import torch

a = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]])
b = torch.tensor([[1, 2, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 1, 0]])

The following shows the effect:

>>> a
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15]])
>>> b
tensor([[1, 2, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0]])
>>> a.gather(0, b)
tensor([[ 6, 12,  3,  4,  5],
        [ 1,  7,  3,  4,  5],
        [ 1,  2,  3,  9,  5]])
>>> a.gather(1, b)
tensor([[ 2,  3,  1,  1,  1],
        [ 6,  7,  6,  6,  6],
        [11, 11, 11, 12, 11]])

The gather function selects elements from the original tensor according to the index. For a 2-D tensor, out[i][j] = a[b[i][j]][j] when dim=0 (the index picks the row for each position) and out[i][j] = a[i][b[i][j]] when dim=1 (the index picks the column). That is why a.gather(0, b)[0][0] is 6: b[0][0] is 1, so the output takes a[1][0].

The first argument is dim, the dimension along which the selection is made (for a matrix: whether the index picks rows or columns).

The second argument is the index tensor. It does not have to have the same shape as a: it only needs the same number of dimensions, and along every dimension other than dim its size must not exceed that of a. The output always takes the shape of the index; see the sketch below.
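For instance, an index with fewer columns than a is fine; the output simply takes the shape of the index (a quick sketch using the same a as above):

import torch

a = torch.tensor([[ 1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10],
                  [11, 12, 13, 14, 15]])

# One column index per row: along dim=1, out[i][0] = a[i][idx[i][0]]
idx = torch.tensor([[0], [4], [2]])
print(a.gather(1, idx))
# tensor([[ 1],
#         [10],
#         [13]])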

Below is a diagram (found online) that illustrates the working principle of the gather function.

(I typed a lot of words, but I still can't explain it clearly QAQ)

[Figure: diagram of how gather indexes into the input tensor; not reproduced here]
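In place of the diagram, here is the same rule written out as explicit Python loops; for the 2-D case this reproduces gather exactly:

import torch

a = torch.tensor([[ 1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10],
                  [11, 12, 13, 14, 15]])
b = torch.tensor([[1, 2, 0, 0, 0],
                  [0, 1, 0, 0, 0],
                  [0, 0, 0, 1, 0]])

# dim=0: the index picks the ROW for every position
out0 = torch.empty_like(b)
for i in range(b.size(0)):
    for j in range(b.size(1)):
        out0[i][j] = a[b[i][j]][j]
assert torch.equal(out0, a.gather(0, b))

# dim=1: the index picks the COLUMN for every position
out1 = torch.empty_like(b)
for i in range(b.size(0)):
    for j in range(b.size(1)):
        out1[i][j] = a[i][b[i][j]]
assert torch.equal(out1, a.gather(1, b))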

Source: blog.csdn.net/weixin_43466027/article/details/117385716