PyTorch: an in-depth look at the reshape(), view(), transpose(), and permute() functions

Preface

The view() function reconstructs a tensor's dimensions, while permute() and transpose() exchange tensor dimensions. A higher-order tensor is composed of several lower-order tensors. For example, a 4th-order tensor with structure (n, c, h, w) is composed of n 3rd-order tensors with structure (c, h, w); a 3rd-order tensor with structure (c, h, w) is composed of c 2nd-order tensors with structure (h, w); and a 2nd-order tensor with structure (h, w) is composed of h 1st-order tensors of length w, where h is the number of rows and w is the number of columns.
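This composition can be checked directly by indexing: each index into a tensor peels off one order. A minimal sketch (the shape (2, 3, 4, 5) is illustrative):

```python
import torch

x = torch.zeros(2, 3, 4, 5)  # a 4th-order tensor with structure (n, c, h, w) = (2, 3, 4, 5)
print(x[0].size())        # torch.Size([3, 4, 5]) -- one of the n 3rd-order tensors
print(x[0][0].size())     # torch.Size([4, 5])    -- one of the c 2nd-order tensors
print(x[0][0][0].size())  # torch.Size([5])       -- one of the h 1st-order tensors of length w
```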

1. reshape()

The reshape() function and the view() function both reorganize dimensions, and their usage is similar. The difference is that view() can only operate on tensors, while reshape() can operate on both tensors and numpy arrays. A code example follows; for the underlying principle, see the view() section.

import numpy as np
import torch

x = np.array([1, 2, 3, 4, 5, 6])  # a 1-D numpy array with 6 elements
y = torch.Tensor([1, 2, 3, 4, 5, 6])  # a 1st-order tensor with 6 elements
print(x.reshape(2, 3))  # reorganize x into an array with structure (2, 3)
print(y.reshape(2, 3))  # reorganize y into a tensor with structure (2, 3)
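As a quick sanity check of this difference: reshape() is available on both numpy arrays and tensors, while view() exists only on tensors — numpy's ndarray.view() reinterprets the dtype rather than reshaping. A minimal sketch:

```python
import numpy as np
import torch

x = np.array([1, 2, 3, 4, 5, 6])
y = torch.Tensor([1, 2, 3, 4, 5, 6])
print(x.reshape(2, 3).shape)   # (2, 3) -- reshape works on numpy arrays
print(y.reshape(2, 3).size())  # torch.Size([2, 3]) -- and on tensors
print(y.view(2, 3).size())     # torch.Size([2, 3]) -- view works only on tensors
# x.view(2, 3) would fail: numpy's ndarray.view() reinterprets the dtype, it does not reshape
```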


2. view()

① From 1st order to higher order

From 1st order to 2nd order

For a 1st-order tensor x, the view(h, w) operation takes w elements at a time from x in index order, h times in total, to form a 2nd-order tensor with structure (h, w); see the example for details.

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8])  # a 1st-order tensor with 8 elements
print(x.view(4, 2))  # returns a 2nd-order tensor with structure (4, 2)


From 1st order to 3rd order

For a 1st-order tensor x, the view(c, h, w) operation takes h*w elements at a time from x in index order, converts them into a 2nd-order tensor with structure (h, w) following the 1st-to-2nd-order method, and repeats this c times to form a 3rd-order tensor with structure (c, h, w); see the example for details.

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])  # a 1st-order tensor with 12 elements
print(x.view(3, 2, 2))  # returns a 3rd-order tensor with structure (3, 2, 2)


From 1st order to 4th order

For a 1st-order tensor x, the view(n, c, h, w) operation takes c*h*w elements at a time from x in index order, converts them into a 3rd-order tensor with structure (c, h, w) following the 1st-to-3rd-order method, and repeats this n times to form a 4th-order tensor with structure (n, c, h, w); see the example for details.

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])  # a 1st-order tensor with 24 elements
print(x.view(2, 2, 2, 3))  # returns a 4th-order tensor with structure (2, 2, 2, 3)


From 1st order to m-th order

For a 1st-order tensor x, the view(i_m, i_{m-1}, ···, i_2, i_1) operation takes i_{m-1}*i_{m-2}*···*i_2*i_1 elements at a time from x in index order, converts them into an (m-1)-order tensor with structure (i_{m-1}, ···, i_2, i_1) following the 1st-to-(m-1)-order method, and repeats this i_m times to form an m-th-order tensor with structure (i_m, i_{m-1}, ···, i_2, i_1), where i_k denotes the size of the k-th index of the tensor.
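This recursive rule can be verified in code: taking the elements in index order, forming the lower-order tensors, and stacking them reproduces what view() does. A minimal sketch for m = 3:

```python
import torch

x = torch.arange(24.)  # a 1st-order tensor with 24 elements
a = x.view(2, 3, 4)    # directly reshape into a 3rd-order (2, 3, 4) tensor
# take 3*4 = 12 elements at a time, form a (3, 4) tensor, repeat 2 times, then stack
b = torch.stack([x[0:12].view(3, 4), x[12:24].view(3, 4)])
print(a.equal(b))  # True
```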

② From 2nd order to m-th order

For a 2nd-order tensor x with structure (h, w), to become a new m-th-order tensor, first expand the 2nd-order tensor by row into a 1st-order tensor of size h*w, then follow the 1st-to-m-th-order method to obtain the m-th-order tensor. Expanding by row means splicing in the w index direction. The code example below converts a 2nd-order tensor into a 3rd-order tensor, and a 1st-order tensor is used to verify the analysis.

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12]])  # a 2nd-order tensor with structure (4, 3)
y = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])  # a 1st-order tensor with 12 elements
print(x.view(2, 2, 3))  # returns a 3rd-order tensor with structure (2, 2, 3)
print(y.view(2, 2, 3))  # returns a 3rd-order tensor with structure (2, 2, 3)


③ From 3rd order to m-th order

For a 3rd-order tensor x with structure (c, h, w), first splice it by row, i.e. in the h index direction, into a 2nd-order tensor with structure (c*h, w), then expand that 2nd-order tensor into a 1st-order tensor following the 2nd-to-1st-order method, and finally apply the 1st-to-m-th-order method to obtain the m-th-order tensor. See Figure 1.1 and Figure 1.2 for an illustration.

The code example for converting a 3rd-order tensor into a 4th-order tensor is shown below; a 2nd-order tensor obtained by splicing is used to verify the above analysis.

x = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[7, 8, 9],
                   [10, 11, 12]],

                  [[13, 14, 15],
                   [16, 17, 18]],

                  [[19, 20, 21],
                   [22, 23, 24]]])  # a 3rd-order tensor with structure (4, 2, 3)
y = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18],
                  [19, 20, 21],
                  [22, 23, 24]])  # a 2nd-order tensor with structure (4*2, 3)
print((y.view(2, 2, 2, 3)).equal(x.view(2, 2, 2, 3)))  # whether the two reshaped tensors are equal
print(x.view(2, 2, 2, 3))  # returns a 4th-order tensor with structure (2, 2, 2, 3)


④ From 4th order to m-th order

For a 4th-order tensor x with structure (n, c, h, w), first splice it in the c index direction into a 3rd-order tensor with structure (n*c, h, w), then convert that 3rd-order tensor into a 1st-order tensor following the 3rd-to-1st-order method, and finally obtain the m-th-order tensor by the 1st-to-m-th-order method. Schematic diagrams of splicing a 4th-order tensor in the c index direction are shown in Figure 2.1 and Figure 2.2.

The code example for converting a 4th-order tensor into a 2nd-order tensor is shown below; a spliced 3rd-order tensor is used to verify the above analysis.

x = torch.Tensor([[[[1, 2, 3],
                    [4, 5, 6]],

                   [[7, 8, 9],
                    [10, 11, 12]]],


                  [[[13, 14, 15],
                    [16, 17, 18]],

                   [[19, 20, 21],
                    [22, 23, 24]]]])  # a 4th-order tensor with structure (2, 2, 2, 3)
y = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[7, 8, 9],
                   [10, 11, 12]],

                  [[13, 14, 15],
                   [16, 17, 18]],

                  [[19, 20, 21],
                   [22, 23, 24]]])  # a 3rd-order tensor with structure (2*2, 2, 3)

print(f'x.size() = {x.size()}')  # x: (2, 2, 2, 3)
print(x.view(4, 6))  # returns a 2nd-order tensor with structure (4, 6)
print(f'y.size() = {y.size()}')  # y: (4, 2, 3)
print(y.view(4, 6))  # returns a 2nd-order tensor with structure (4, 6)


3. transpose()

The transpose() function exchanges two dimensions at a time; its arguments are dimension indices 0, 1, 2, 3, ···, so the higher the order of the tensor, the more index pairs are available.

② 2nd order tensor

For a 2nd-order tensor with structure (h, w), the corresponding transpose() arguments are the two indices (0, 1). The transpose(0, 1) operation exchanges the h and w dimensions, and the result is the same as the familiar matrix transpose. See the code example below.

x = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # a 2nd-order tensor with structure (3, 2)
print(f'x.size() = {x.size()}')  # returns the structure of tensor x
y = x.transpose(0, 1)  # exchange the h and w dimensions
# y = x.t()  # equivalently, transpose x
print(f'y.size() = {y.size()}')  # returns the structure of tensor y
print(y)  # prints tensor y after the exchange; its structure is (2, 3)


③ 3rd order tensor

For a 3rd-order tensor with structure (c, h, w), the corresponding transpose() arguments are the three indices (0, 1, 2). The transpose(0, 1) operation exchanges the c and h dimensions, as illustrated in Figure 3.1 and Figure 3.2; exchanging other pairs of dimensions works the same way. If it is hard to picture, trace through the figures a few times.

The code example for exchanging the c and h dimensions of a 3rd-order tensor is as follows. It is not difficult to see that applying transpose() to the c and h indices of a 3rd-order tensor amounts to rotating the tensor about the w index direction as the axis.

x = torch.Tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]],
                  [[13, 14, 15], [16, 17, 18]],
                  [[19, 20, 21], [22, 23, 24]]])  # a 3rd-order tensor with structure (4, 2, 3)
print(f'x.size() = {x.size()}')  # returns the structure of tensor x
print(x.transpose(0, 1))  # exchange the c and h dimensions; the result has structure (2, 4, 3)


④ 4th order tensor

For a 4th-order tensor with structure (n, c, h, w), the corresponding transpose() arguments are the four indices (0, 1, 2, 3), and the corresponding operations are more complicated. For ease of understanding, they are divided into three representative cases: transpose(0, 1), transpose(0, 3), and transpose(1, 2). See the analysis below.

3.4.1 The transpose(0, 1) operation exchanges the n and c dimensions, as illustrated in Figure 4.1 and Figure 4.2.


The code example below exchanges the n and c dimensions of a 4th-order tensor; exchanging other dimensions works similarly. It is not difficult to see that transpose(0, 1) on a 4th-order tensor regroups the channels along the n index direction: in the code below, the original tensor has 2 groups along the n index direction with 3 channels each, and after exchanging the n and c dimensions it has 3 groups with 2 channels each.

x = torch.Tensor([[[[1, 2], [3, 4]],
                   [[5, 6], [7, 8]],
                   [[9, 10], [11, 12]]],

                  [[[13, 14], [15, 16]],
                   [[17, 18], [19, 20]],
                   [[21, 22], [23, 24]]]])  # a 4th-order tensor with structure (2, 3, 2, 2)
print(f'x.size() = {x.size()}')  # returns the structure of tensor x
print(x.transpose(0, 1))  # returns the tensor with n and c exchanged; its structure is (3, 2, 2, 2)

3.4.2 The transpose(0, 3) operation exchanges the n and w dimensions. A schematic diagram for this exchange is hard to draw, so a code explanation is given instead; transpose(0, 2) works in the same way. Note that this kind of exchange is rarely used in practice.

For a 4th-order tensor x with structure (2, 2, 2, 3), the transpose(0, 3) operation turns it into a new 4th-order tensor with structure (3, 2, 2, 2). It can be understood as exchanging the n and w indices of each element while keeping its c and h indices unchanged, similar to a coordinate-system transformation.

x = torch.Tensor([[[[1, 2, 3], [4, 5, 6]],
                   [[7, 8, 9], [10, 11, 12]]],

                  [[[13, 14, 15], [16, 17, 18]],
                   [[19, 20, 21], [22, 23, 24]]]])  # a 4th-order tensor with structure (2, 2, 2, 3)
print(f'x.size() = {x.size()}')  # returns the structure of tensor x
y = x.transpose(0, 3)  # exchange the n and w dimensions
print(f'y.size() = {y.size()}')  # returns the structure of tensor y
print(y)

3.4.3 The transpose(1, 2) operation exchanges the c and h dimensions. It is the same as exchanging the dimensions of a 3rd-order tensor above: for a 4th-order tensor, each of its n 3rd-order tensors simply has its dimensions exchanged. The same holds for transpose(1, 3) and transpose(2, 3).
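This claim can be verified directly: transpose(1, 2) on a 4th-order tensor matches applying transpose(0, 1) to each of its n 3rd-order sub-tensors and restacking them. A minimal sketch:

```python
import torch

x = torch.arange(24.).view(2, 3, 2, 2)  # a 4th-order tensor with structure (n, c, h, w) = (2, 3, 2, 2)
a = x.transpose(1, 2)  # exchange the c and h dimensions; structure becomes (2, 2, 3, 2)
# exchange dimensions (0, 1) of each of the n 3rd-order sub-tensors, then restack
b = torch.stack([x[i].transpose(0, 1) for i in range(x.size(0))])
print(a.equal(b))  # True
```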

4. permute()

The permute() function can exchange several dimensions at once, or rearrange all of them; its arguments are dimension indices 0, 1, 2, 3, ···, and the higher the order of the tensor, the more choices are available. In essence it can be understood as a superposition of multiple transpose() operations, so the key to understanding permute() is understanding transpose(). The code example is as follows.

x = torch.Tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],

                  [[13, 14, 15, 16],
                   [17, 18, 19, 20],
                   [21, 22, 23, 24]]])  # a 3rd-order tensor with structure (2, 3, 4)
print(f'x.size() = {x.size()}')  # returns the structure of tensor x
y = x.permute(2, 0, 1)  # rearrange the dimensions of tensor x
z = x.transpose(0, 1).transpose(0, 2)  # exchange two dimensions of x twice in a row
print(y.equal(z))  # whether tensors y and z are identical
print(f'z.size() = {z.size()}')  # returns the structure of tensor z
print(z)


Conclusion

From the above analysis, it can be concluded that reshape() and view() can set the target dimensions as needed, as long as the total number of elements matches, while transpose() and permute() can only rearrange existing dimensions. In addition, transpose() differs slightly between pytorch and numpy: the transpose() function in numpy is equivalent to permute() in pytorch.
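Both points can be demonstrated in a short sketch: numpy's transpose() takes a full axis permutation like pytorch's permute(), and view() refuses the non-contiguous result of a transpose() while reshape() copies when needed.

```python
import numpy as np
import torch

a = np.arange(24).reshape(2, 3, 4)
b = torch.arange(24).view(2, 3, 4)
# numpy's transpose() accepts a full axis permutation, like pytorch's permute()
print(a.transpose(2, 0, 1).shape)  # (4, 2, 3)
print(b.permute(2, 0, 1).size())   # torch.Size([4, 2, 3])

t = b.transpose(0, 1)  # non-contiguous after exchanging dimensions
# t.view(6, 4) would raise a RuntimeError here, but reshape() copies if needed
print(t.reshape(6, 4).size())      # torch.Size([6, 4])
```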


Origin blog.csdn.net/Wenyuanbo/article/details/119779521