Swapping dimensions with the permute() function in PyTorch

First, here is the explanation of permute in the official documentation:

https://pytorch.org/docs/stable/tensors.html?highlight=permute#torch.Tensor.permute

Second, I will explain what the official documentation says in plain language.

The word "permute" means to rearrange, or reorder. permute() rearranges the dimensions of a tensor into any order you specify.

Let's look at a demo:

import torch
a = torch.randn(2, 3, 5)   # a has shape (2, 3, 5)
b = a.permute(1, 2, 0)     # new order: a's dims 1, 2, then 0
print(b.shape)             # torch.Size([3, 5, 2])

Note that the arguments to permute are dimension indices of tensor a. The number of arguments must therefore equal the number of dimensions of a. Together, the arguments must be a permutation of 0, 1, ..., a.dim()-1, so that every dimension of a is referenced exactly once.
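The rule above can be expressed as a short shape computation. `expected_shape` is a hypothetical helper written only for this illustration. It is not part of PyTorch, and it handles only non-negative indices. The loop at the end passes invalid orderings to the real `permute`. The exact exception type and message may differ between PyTorch versions, so the sketch catches both `RuntimeError` and `IndexError`.

```python
import torch

def expected_shape(shape, dims):
    """Hypothetical helper: the shape tensor.permute(*dims) should produce.

    Assumes dims must be a permutation of 0..len(shape)-1
    (non-negative indices only, for simplicity).
    """
    if sorted(dims) != list(range(len(shape))):
        raise ValueError(f"{dims} is not a permutation of 0..{len(shape) - 1}")
    # Output dimension i takes its size from input dimension dims[i].
    return tuple(shape[d] for d in dims)

a = torch.randn(2, 3, 5)
print(expected_shape(tuple(a.shape), (1, 2, 0)))  # (3, 5, 2)
print(tuple(a.permute(1, 2, 0).shape))            # (3, 5, 2)

# Wrong length, a repeated index, and an out-of-range index are all rejected.
for bad in [(1, 0), (0, 0, 1), (0, 1, 3)]:
    try:
        a.permute(*bad)
    except (RuntimeError, IndexError) as err:
        print(bad, "rejected:", type(err).__name__)
```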

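a.permute(1,2,0) moves dimension 0 of a (the one of size 2) to the end. Dimension i of the result is the dimension of a named by the i-th argument, so b has shape (3, 5, 2).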
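The same example can be checked element by element. This sketch uses `torch.arange` so the values are predictable. With the order (1, 2, 0), the element at `b[i, j, k]` comes from `a[k, i, j]`. `torch.movedim(a, 0, -1)` is included as another way to express "move dimension 0 to the end". It requires PyTorch 1.7 or later.

```python
import torch

a = torch.arange(2 * 3 * 5).reshape(2, 3, 5)  # small, predictable values
b = a.permute(1, 2, 0)                        # shape (3, 5, 2)

# Output axis 0 is a's axis 1, output axis 1 is a's axis 2,
# and output axis 2 is a's axis 0, so b[i, j, k] == a[k, i, j].
for i in range(3):
    for j in range(5):
        for k in range(2):
            assert b[i, j, k] == a[k, i, j]

print(b.shape)                                   # torch.Size([3, 5, 2])
print(torch.equal(b, torch.movedim(a, 0, -1)))   # True
```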

In conclusion

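For a given tensor, the total number of elements is fixed. permute only changes the order of the dimensions, for example turning an m×n×c tensor into n×m×c, c×n×m, or any other ordering. It returns a view of the original data rather than a copy.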
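The sketch below enumerates every ordering of the axes of an m×n×c tensor and confirms that the element count never changes. The sizes m, n, c = 2, 3, 4 are arbitrary choices for illustration. The second part shows that the permuted result shares storage with the original, so a write through the view is visible in x.

```python
import itertools
import torch

m, n, c = 2, 3, 4
x = torch.randn(m, n, c)

# Every ordering of the three axes is valid; numel() stays m * n * c.
for dims in itertools.permutations(range(x.dim())):
    y = x.permute(*dims)
    print(dims, tuple(y.shape), y.numel())

# permute returns a view: no data is copied.
y = x.permute(2, 0, 1)                 # c x m x n
print(y.data_ptr() == x.data_ptr())    # True
y[0, 0, 0] = 100.0                     # same element as x[0, 0, 0]
print(x[0, 0, 0].item())               # 100.0
```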

For a two-dimensional tensor, permute(1,0) is a transpose, equivalent to transpose(0,1).
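A quick check of the 2-D case, using a small `arange` matrix chosen for illustration. `.t()` is included because it is PyTorch's shorthand transpose for 2-D tensors. The last line shows the difference for more dimensions. `transpose` swaps exactly two axes, so it matches only a permute that swaps those two, while permute can reorder all axes in one call.

```python
import torch

m = torch.arange(6).reshape(2, 3)

p = m.permute(1, 0)      # reorder axes as (1, 0)
t = m.transpose(0, 1)    # swap axes 0 and 1
print(p.shape)                                    # torch.Size([3, 2])
print(torch.equal(p, t), torch.equal(p, m.t()))   # True True

# In 3-D, transpose(0, 2) swaps the first and last axes,
# which is the same as permute(2, 1, 0).
x = torch.randn(2, 3, 5)
print(torch.equal(x.transpose(0, 2), x.permute(2, 1, 0)))  # True
```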


Origin blog.csdn.net/leviopku/article/details/108752028