[Simulation Basics] [PyTorch] Usage and explanation of tensor reshaping-related functions

Tensor transformations come up all the time in PyTorch. While writing GNN-related code recently, I used a large number of these operations, so to avoid having to re-test them in the future, the relevant content, test results, and my personal understanding are organized below. It serves everyone while I use it myself. Everyone is welcome to bookmark it and help me improve it, thank you~

0. The basic tensors (Tensor) used throughout this article

import torch
a = torch.randn([3, 2, 4, 5])
b = torch.randn([3, 2, 5])

1. Tensor shape

If you want to know the shape of a Tensor, you can use the shape attribute.

print(a.shape)  # output: torch.Size([3, 2, 4, 5])
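As a small side note, size() returns the same information as the shape attribute, so either can be used; a quick check:

print(a.size())  # output: torch.Size([3, 2, 4, 5]), identical to a.shape
print(a.shape == a.size())  # True: shape and size() are interchangeable for reading the shape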

2. Permuting Tensor dimensions

To permute the dimensions of a Tensor, use permute().

(1) When only two dimensions of the Tensor need to be swapped, transpose() can be used.

Although transpose() is just as fast as permute() in this case (tested and verified), permute() requires all of the Tensor's dimensions to be written out, while transpose() only needs the two dimensions being swapped.

For example: a.transpose(1, 3) has the same effect as a.permute(0, 3, 2, 1), but the latter must spell out all four dimensions of a.

print(a.transpose(1, 3).shape)  # output: torch.Size([3, 5, 4, 2])
print(a.transpose(3, 1).shape)  # the order of the two arguments to transpose() does not affect the result, so the output here is the same as above (personal habit: I like to put the smaller index first)
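As a quick sanity check that the two calls really are equivalent:

print(torch.equal(a.transpose(1, 3), a.permute(0, 3, 2, 1)))  # True: the two results are element-for-element identical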

(2) When more than two dimensions of the Tensor need to be rearranged, permute() or multiple transpose() calls can be used.

For example: a.transpose(1, 3).transpose(2, 3) has the same effect as a.permute(0, 3, 1, 2), but permute() is recommended: in a test on a larger 4-D tensor (timing small tensors is not meaningful), permute() was about 50% faster than the chained transpose() calls.
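The 50% figure is from my own test; a rough timing sketch along the following lines can be used to reproduce the comparison (the tensor size and loop count below are arbitrary, and the exact numbers will vary with hardware and PyTorch version):

import time

big = torch.randn(64, 128, 256, 32)  # a larger 4-D tensor; the size here is an arbitrary example

start = time.perf_counter()
for _ in range(10000):
    big.transpose(1, 3).transpose(2, 3)
t_transpose = time.perf_counter() - start

start = time.perf_counter()
for _ in range(10000):
    big.permute(0, 3, 1, 2)
t_permute = time.perf_counter() - start

print(f"two transpose() calls: {t_transpose:.4f} s, one permute(): {t_permute:.4f} s")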

3. Reshaping a Tensor

(1) General reshaping of a Tensor: reshape()
(2) Reshaping that shares data with the original Tensor: view()

Note: If you only need to use the reshaped result once, you can use view() instead of reshape(). view() never copies data: it returns a tensor that shares storage with the original (and raises an error if that is not possible), whereas reshape() may allocate new memory and copy the data when the tensor is not contiguous.
For details, please refer to: PyTorch: Detailed explanation of the difference between view() and reshape()

print(a.reshape(3*2, 4*5).shape)  # output: torch.Size([6, 20])
print(a.view(3*2, 4*5).shape)  # the shape here is [6, 20]; the data is shared with Tensor a
c = a.view(3*2, 4*5)  # the result of view() can be assigned to another variable; c still shares storage with a, so no new memory is allocated for the data
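To make the sharing concrete, the following quick check shows that the result of view() uses the same underlying storage as a, so an in-place change through the view is also visible in a:

c = a.view(3*2, 4*5)
print(c.data_ptr() == a.data_ptr())  # True: c and a point to the same underlying memory
c[0, 0] = 100.0                      # an in-place change through the view ...
print(a[0, 0, 0, 0])                 # ... is visible in a (prints tensor(100.))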

4. Tensor concatenation

In practice, it is sometimes necessary to expand one dimension of a tensor.

(1) General concatenation: put the Tensors to be concatenated into a list [] (more than two are allowed) and pass it to torch.cat(). Requirement: except for the dimension being concatenated along, the other dimensions of every Tensor in the list must have the same size.

Note: concatenation cannot be done with [] alone. Although each element inside the brackets is a Tensor, the result of [] itself is a Python list, not a Tensor, which is not what we want; the list must be passed to torch.cat().

d = b.reshape(3, 2, 1, 5)  # d -> shape: [3, 2, 1, 5]
e1 = torch.cat([a, d], dim=2)  # e1 -> shape: [3, 2, 4+1, 5]

# a more concise way to write it
e2 = torch.cat([a, b.view(3, 2, 1, 5)], dim=2)  # same result as e1, without the intermediate variable d
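As a side note, unsqueeze() is another common way to insert a size-1 dimension, which avoids spelling out the full shape when only one new axis is needed:

e3 = torch.cat([a, b.unsqueeze(2)], dim=2)  # b.unsqueeze(2) -> shape: [3, 2, 1, 5]
print(torch.equal(e1, e3))  # True: same result as e1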

(2) Expanding a dimension through concatenation

f = torch.cat([b.view(3, 2, 1, 5)]*4, dim=2)  # f -> shape: [3, 2, 4, 5]
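As a side note, the same expansion can also be obtained with unsqueeze() plus expand() or repeat(); expand() does not copy any data, so it can be preferable when the result is only read (f2 and f3 below are just illustrative names):

f2 = b.unsqueeze(2).expand(3, 2, 4, 5)  # shape: [3, 2, 4, 5], no data copy (shares storage with b)
f3 = b.unsqueeze(2).repeat(1, 1, 4, 1)  # shape: [3, 2, 4, 5], explicit copy
print(torch.equal(f, f2), torch.equal(f, f3))  # True True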

5. (The core of this article) Combining concatenation and reshaping

This section explains how to make sure that, after a series of concatenations and reshapes, the data in each dimension of the Tensor does not get scrambled. This is the core content of this article, and it is the foolproof recipe I arrived at after my own testing.

Task: take the Tensor b of shape [3, 2, 5], copy and expand it into a Tensor of shape [3, 2, 4, 5], and then reshape it into a Tensor of shape [2*4*3, 5].

Method 1

c = torch.cat([b.view(3, 2, 1, 5)]*4, dim=2)  # c -> shape: [3, 2, 4, 5]
d = c.permute([1, 2, 0, 3]).reshape(-1, 5)  # d -> shape: [2*4*3, 5]

Note: the second line of code cannot simply reshape c directly; otherwise the resulting Tensor would effectively have shape [3*2*4, 5], and when it is later concatenated or combined with other Tensors of the same (or similar) shape, the dimensions and data will be mixed up.
For example, suppose another Tensor of shape [2*4*3, 5] is to be added to d. If the dimensions are scrambled, the elements being added no longer correspond; fixing this by converting back to the [2, 4, 3, 5] shape would require extra code.
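To see this concretely, here is a small sketch using distinct values (torch.arange is used purely so that the rows are traceable); the direct reshape and the permute-then-reshape produce the rows in different orders:

small = torch.arange(3 * 2 * 4 * 5).reshape(3, 2, 4, 5)  # distinct values for traceability
direct = small.reshape(-1, 5)                      # flattens in (3, 2, 4) order -> shape [3*2*4, 5]
proper = small.permute(1, 2, 0, 3).reshape(-1, 5)  # flattens in (2, 4, 3) order -> shape [2*4*3, 5]
print(torch.equal(direct, proper))  # False: the rows are ordered differently
print(direct[1])  # this row is small[0, 0, 1, :]
print(proper[1])  # this row is small[1, 0, 0, :]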

Method 2

c = torch.cat([b]*4, dim=0)  # c -> shape: [4*3, 2, 5]
d = c.transpose(0, 1).reshape(-1, 5)  # d -> shape: [2*4*3, 5]

Note: here too the second line of code cannot simply reshape c directly; otherwise the resulting Tensor would effectively have shape [4*3*2, 5].
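As a quick check that the two methods agree, they can be written out side by side (d1 and d2 are just illustrative names); they produce identical tensors:

d1 = torch.cat([b.view(3, 2, 1, 5)] * 4, dim=2).permute(1, 2, 0, 3).reshape(-1, 5)  # Method 1
d2 = torch.cat([b] * 4, dim=0).transpose(0, 1).reshape(-1, 5)                       # Method 2
print(torch.equal(d1, d2))  # True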

Before explaining the key point, let's take b as an example. The shape of b is [3, 2, 5]; written out explicitly (a random example generated by randn), b is

  tensor([[[ 0.3543, -0.9587, -0.6313,  1.5067,  1.4628],
           [-0.0671,  1.1080,  0.5200, -0.2528,  0.2759]],
          [[ 0.1023, -1.7001,  0.0717,  0.2326,  0.1111],
           [-0.8022,  0.6989, -0.6247, -1.1926, -0.3376]],
          [[ 0.4788,  0.3146,  0.4460, -0.0280, -1.0335],
           [-2.4860,  0.7232,  0.5325,  0.4981, -0.0081]]])

As can be seen from the example above, the "5" in the shape refers to a vector of shape [5,]; the "2" in the shape means there are 2 such [5,] vectors; and the "3" in the shape means there are 3 (groups of 2 [5,] vectors).

Next, let's look back at the first line of code in Method 2. Code of the form

torch.cat([b]*n, dim=i)

multiplies the size of the i-th dimension of b (call it d_i) by n; it means that the i-th dimension now contains n blocks of size [d_i,], so the product should be read as n*d_i rather than d_i*n. When this result is then transformed with reshape(), the correct shape is [..., n, d_i, ...] rather than [..., d_i, n, ...]. (Here [..., n, d_i, ...] means there are ... (n (d_i (vectors of ...))).)
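A tiny traceable example of this rule (torch.arange is used purely for illustration):

x = torch.arange(6).reshape(3, 2)    # d_0 = 3, d_1 = 2
y = torch.cat([x] * 2, dim=0)        # shape [2*3, 2]: 2 stacked copies of x
print(torch.equal(y.reshape(2, 3, 2)[0], x))  # True: reading the result as [n, d_0, d_1] recovers each copy
print(y.reshape(3, 2, 2)[0])                  # tensor([[0, 1], [2, 3]]): reading it as [d_0, n, d_1] does not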

If this was useful to you, please give it a thumbs up to encourage me to keep writing~


Reprinted from: blog.csdn.net/AbaloneVH/article/details/128942839