A guide to PyTorch tensor dimension conversion

# view()    change shape
# reshape() change shape
# permute() reorder dimensions
# squeeze()/unsqueeze() remove/add a size-1 dimension
# expand()   expand a tensor
# narrow()   narrow a tensor
# resize_()  reset the size
# repeat(), unfold() repeat a tensor
# cat(), stack()     concatenate/stack tensors

1 tensor.view()

view() changes the shape of a tensor without changing its element values.
Usage 1:
For example, view can transform a tensor of shape (2, 3) into a tensor of shape (3, 2):

import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
y = x.view(3, 2)                          # shape (3, 2)

The above operation is equivalent to first flattening the tensor of shape **(2, 3)** into **(1, 6)**, and then reshaping it into **(3, 2)**.

Usage 2:
The number of elements stays the same before and after the conversion. If one of the dimensions passed to view() is -1, its size is inferred from the total number of elements and the sizes of the other dimensions. Note that at most one dimension in view() can be set to -1.

z = x.view(-1,2)

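Continuing the example above, a minimal sketch of how the -1 dimension is inferred:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])   # 6 elements in total

# -1 is inferred: 6 elements / 2 columns = 3 rows
z = x.view(-1, 2)          # shape (3, 2)

# -1 may appear in any position, but only once
w = x.view(3, -1)          # also shape (3, 2)
print(z.shape, w.shape)
```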

For example:
In convolutional neural networks, view is often used before the fully connected layer to flatten the feature tensor:
Assume the input feature is a 4-D tensor of shape B×C×H×W, where B is the batch size, C is the number of channels, and H and W are the height and width of the feature. Before the feature is fed to the fully connected layer, .view is used to convert it into a 2-D tensor of shape B×(C·H·W): the batch dimension stays unchanged, and each feature is flattened into a one-dimensional vector.

2 tensor.reshape()

reshape() is used in the same way as view(). The difference is that view() requires the tensor's memory to be contiguous, while reshape() also works on non-contiguous tensors, copying the data when necessary.
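A small sketch of the difference: after a transpose the tensor is no longer contiguous, so view() fails while reshape() still works.

```python
import torch

x = torch.arange(6).view(2, 3)
t = x.t()                  # transpose -> non-contiguous memory

# view() requires contiguous memory and raises an error here
try:
    t.view(6)
except RuntimeError:
    print("view() failed on a non-contiguous tensor")

# reshape() copies the data if necessary, so it always works
y = t.reshape(6)
print(y)    # tensor([0, 3, 1, 4, 2, 5])
```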

3 tensor.squeeze() and tensor.unsqueeze()

3.1 tensor.squeeze(): removing dimensions

(1) If squeeze() is called with no argument, every dimension of size 1 in the tensor is removed. For example, a tensor of shape (1, 2, 1, 9) is reduced to shape (2, 9). If no dimension has size 1, the shape stays the same: squeezing a (2, 3, 4) tensor leaves it as (2, 3, 4).
(2) If squeeze(idx) is called, only dimension idx is removed. For example, squeeze(2) on a tensor of shape (1, 2, 1, 9) gives a tensor of shape (1, 2, 9). If dimension idx does not have size 1, the shape is unchanged after squeeze.
For example:
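A minimal sketch of both cases, using a (1, 2, 1, 9) tensor:

```python
import torch

x = torch.zeros(1, 2, 1, 9)

# no argument: every size-1 dimension is removed
a = x.squeeze()      # shape (2, 9)

# squeeze(2): remove dimension 2, which has size 1
b = x.squeeze(2)     # shape (1, 2, 9)

# squeezing a dimension whose size is not 1 changes nothing
c = x.squeeze(1)     # shape (1, 2, 1, 9)
print(a.shape, b.shape, c.shape)
```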

3.2 tensor.unsqueeze(idx): adding a dimension

unsqueeze(idx) inserts a new dimension of size 1 at position idx, raising the tensor from n dimensions to n+1 dimensions. For example, a tensor of shape (2, 3) becomes a tensor of shape (1, 2, 3) after unsqueeze(0).
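A minimal sketch of inserting the new dimension at different positions:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])   # shape (2, 3)

# insert a new size-1 dimension at the given position
a = x.unsqueeze(0)   # shape (1, 2, 3)
b = x.unsqueeze(1)   # shape (2, 1, 3)
c = x.unsqueeze(2)   # shape (2, 3, 1)
print(a.shape, b.shape, c.shape)
```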

4 tensor.permute()

permute() performs a coordinate-system transformation, i.e. a generalized matrix transpose; its usage is similar to transpose for numpy arrays. The numbers in the permute() brackets are the index values of the dimensions in the desired order. permute is a common technique in deep learning: a BCHW feature tensor is transposed into a BHWC feature tensor, i.e. the channel dimension is moved to the last position, by calling **tensor.permute(0, 2, 3, 1)**.
torch.transpose can only swap two dimensions at a time, while permute() can reorder all the dimensions of a high-dimensional tensor in one call.
Put simply: permute() operates on several dimensions of a tensor at once, while transpose acts on exactly two dimensions at a time.

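A minimal sketch of the BCHW → BHWC conversion described above (sizes chosen only for illustration):

```python
import torch

feat = torch.randn(2, 3, 4, 5)        # B, C, H, W

# move the channel dimension to the end: BCHW -> BHWC
bhwc = feat.permute(0, 2, 3, 1)       # shape (2, 4, 5, 3)

# transpose swaps exactly two dimensions at a time
swapped = feat.transpose(1, 3)        # shape (2, 5, 4, 3)
print(bhwc.shape, swapped.shape)
```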

Although both permute and view/reshape can convert a tensor to a given shape, their principles are completely different, so take care to distinguish them. After view or reshape, the linear order of the elements in the tensor does not change; after a permute transposition, the arrangement of the elements changes, because the coordinate system itself changes.
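The difference can be seen on a small example: both calls below produce a (3, 2) tensor, but the values are laid out differently.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])   # shape (2, 3)

# view keeps the linear order of elements: 1, 2, 3, 4, 5, 6
v = x.view(3, 2)      # [[1, 2], [3, 4], [5, 6]]

# permute reorders the axes, so the element layout changes
p = x.permute(1, 0)   # [[1, 4], [2, 5], [3, 6]]
print(v)
print(p)
```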

5 torch.cat([a,b],dim)

When concatenating tensors along dimension dim, all the other dimensions must stay consistent.
Suppose a is a 2-D tensor of shape (h1, w1) and b is a 2-D tensor of shape (h2, w2). torch.cat([a, b], 0) concatenates along the first dimension, i.e. the column direction, so w1 must equal w2. torch.cat([a, b], 1) concatenates along the second dimension, i.e. the row direction, so h1 must equal h2.
Suppose a is a 3-D tensor of shape (c1, h1, w1) and b is a 3-D tensor of shape (c2, h2, w2). torch.cat([a, b], 0) concatenates along the first dimension, i.e. the channel dimension of the feature; the other dimensions must match, i.e. h1 = h2 and w1 = w2. torch.cat([a, b], 1) concatenates along the second dimension, i.e. the height direction, so c1 = c2 and w1 = w2 must hold. torch.cat([a, b], 2) concatenates along the third dimension, i.e. the width direction, so c1 = c2 and h1 = h2 must hold.
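A minimal sketch of the 2-D cases, plus stack() from the summary list at the top, which joins along a new dimension instead of an existing one:

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)
ab = torch.cat([a, b], 0)   # shape (6, 3); widths must match

c = torch.zeros(2, 3)
d = torch.ones(2, 5)
cd = torch.cat([c, d], 1)   # shape (2, 8); heights must match

# stack() inserts a NEW dimension, so the inputs must have identical shapes
s = torch.stack([c, c], 0)  # shape (2, 2, 3)
print(ab.shape, cd.shape, s.shape)
```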

6 tensor.expand()

expand() enlarges a single size-1 dimension of a tensor to a larger size. No data is copied: the result is a view that shares memory with the original tensor (use repeat() when an actual copy is needed). expand() does not change the original tensor, so the result must be assigned to a variable. Specific examples follow.
Take a 2-D tensor as an example: for a tensor of shape (1, n) or (n, 1), calling tensor.expand(s, n) or tensor.expand(n, s) expands it along the row or column direction respectively.
The arguments of expand() are the target size.

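A minimal sketch for the (1, n) case:

```python
import torch

row = torch.tensor([[1, 2, 3]])   # shape (1, 3)

# enlarge the size-1 dimension to 4; the result is a view, no data is copied
e = row.expand(4, 3)              # shape (4, 3)

# the original tensor is unchanged; the result must be assigned
print(row.shape, e.shape)
```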

7 tensor.narrow(dim, start, len)

The narrow() function filters data along a given dimension.

torch.narrow(input, dim, start, length) → Tensor

input is the tensor to be sliced, dim is the dimension to slice along, start is the starting index, and length is the slice length.

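A minimal sketch of slicing along each dimension of a (3, 4) tensor:

```python
import torch

x = torch.arange(12).view(3, 4)

# dim=0, start=1, length=2: keep rows 1 and 2
rows = x.narrow(0, 1, 2)           # shape (2, 4)

# dim=1, start=0, length=2: keep the first two columns
cols = torch.narrow(x, 1, 0, 2)    # shape (3, 2)
print(rows)
print(cols)
```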

8 tensor.resize_()

resize_() changes the tensor's size in place. If the new shape has fewer elements, the tensor is truncated; if it has more, the extra elements are uninitialized.
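A minimal sketch of truncation (shrinking keeps the leading elements in row-major order; enlarging would leave the new slots with arbitrary values):

```python
import torch

x = torch.arange(6)    # tensor([0, 1, 2, 3, 4, 5])

# in place: truncate to a (2, 2) tensor, keeping the first 4 elements
x.resize_(2, 2)
print(x)    # tensor([[0, 1], [2, 3]])
```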

9 tensor.repeat()

tensor.repeat(a, b) tiles the entire tensor: a copies along the row direction and b copies along the column direction.

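A minimal sketch of tiling a (2, 2) tensor:

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])   # shape (2, 2)

# tile the whole tensor: 2 copies along rows, 3 copies along columns
y = x.repeat(2, 3)                   # shape (4, 6)
print(y)

# unlike expand(), repeat() actually copies the data
```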

References:

Functions related to tensor dimension changes in PyTorch (continuously updated) - weili21's article - Zhihu
https://zhuanlan.zhihu.com/p/438099006

PyTorch tensor dimension conversion
https://blog.csdn.net/x_yan033/article/details/104965077


Origin: blog.csdn.net/Alexa_/article/details/134171416