Detailed explanation of PyTorch view() and reshape()

Foreword

If you don’t have time to read on, here’s the conclusion:

  • Both are used to change the shape of a tensor.

  • view() only works on tensors that satisfy the contiguity condition; it does not allocate new memory, it only creates a new alias and reference onto the original storage, and the return value is a view.

  • reshape() works whether or not the tensor satisfies the contiguity condition. If it does, the return value is a view; otherwise a copy is returned (equivalent to calling contiguous() first and then view()); see the sketch after this list.

  • If memory overhead matters and you need to guarantee that the reshaped tensor shares storage with the original, use view().

  • Anything view() can do, reshape() can also do. If you just want to change a tensor's shape, use reshape() without a second thought.
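The points above can be demonstrated directly. Below is a minimal sketch (the tensor names are illustrative only):

import torch

a = torch.arange(6)          # contiguous 1-D tensor
v = a.view(2, 3)             # view() works on a contiguous tensor and shares storage
r = a.reshape(2, 3)          # on a contiguous tensor, reshape() also returns a view
print(v.storage().data_ptr() == a.storage().data_ptr())  # True: shared storage
print(r.storage().data_ptr() == a.storage().data_ptr())  # True: shared storage

t = v.t()                    # transposing makes t non-contiguous
# t.view(6) would raise a RuntimeError here; reshape() still works but copies:
c = t.reshape(6)
print(c.storage().data_ptr() == t.storage().data_ptr())  # False: new storage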

Introduction to PyTorch Tensor

To understand the difference between view and reshape in depth, you must first understand some of the underlying principles of PyTorch tensor storage, namely the tensor's header information area (Tensor), its storage area (Storage), and the tensor's stride.

Tensor documentation link

Tensor storage structure introduction

Tensor data is stored in two separate parts: the header information area (Tensor) and the storage area (Storage), as shown in Figure 1. The variable name and the data it holds live in these two areas respectively. For example, suppose we define and initialize a tensor named A: A's shape, stride, data index, and other metadata are kept in the header information area, while the actual data A holds is kept in the storage area. Furthermore, if we slice, transpose, or otherwise modify A and assign the result to B, then B shares A's storage area; the data in the storage area is unchanged, and only the way B's header information indexes into that data differs. If you have heard of shallow copy and deep copy, you will recognize this as a shallow copy.

[Figure 1: a tensor's header information area (Tensor) and storage area (Storage)]

The code example is as follows:

import torch
a = torch.arange(5)  # initialize tensor a as [0, 1, 2, 3, 4]
b = a[2:]            # slice part of a and assign it to b; b merely indexes a's data differently
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr())  # print the address of a's storage
print('ptr of storage of b:', b.storage().data_ptr())  # print the address of b's storage: the two share one storage
 
print('==================================================================')
 
b[1] = 0    # modify the element at index 1 of b (index 3 of a) to 0
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr())  # the value at the corresponding position of a changed too, confirming shared storage
print('ptr of storage of b:', b.storage().data_ptr())  # print the address of b's storage
 

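For reference, running the snippet prints something like the following (the storage address shown here is illustrative and varies per run, but a and b always report the same one):

a: tensor([0, 1, 2, 3, 4])
b: tensor([2, 3, 4])
ptr of storage of a: 140335209479104
ptr of storage of b: 140335209479104
==================================================================
a: tensor([0, 1, 2, 0, 4])
b: tensor([2, 0, 4])
ptr of storage of a: 140335209479104
ptr of storage of b: 140335209479104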

Tensor's stride property

A torch tensor also has a stride attribute. Does "stride" sound familiar? Yes, the sliding of a convolution kernel over a feature map in a convolutional neural network also has a stride, but the two strides have completely different meanings. A tensor's stride can be understood as the distance, in the underlying storage, between one element and the next along each dimension.

[Figure: how a tensor's stride maps indices to offsets in storage]

Let's look at the following example:


import torch
a = torch.arange(6).reshape(2, 3)  # initialize tensor a
b = torch.arange(6).view(3, 2)     # initialize tensor b
print('a:', a)
print('stride of a:', a.stride())  # print a's stride
print('b:', b)
print('stride of b:', b.stride())  # print b's stride

'''   Output   '''
a: tensor([[0, 1, 2],
        [3, 4, 5]])
stride of a: (3, 1)
b: tensor([[0, 1],
        [2, 3],
        [4, 5]])
stride of b: (2, 1)
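Note that for a contiguous tensor, stride[i] is just the product of the sizes of all later dimensions. Below is a minimal sketch of that rule (contiguous_stride is a hypothetical helper for illustration, not a PyTorch API):

def contiguous_stride(shape):
    # Row-major (contiguous) stride: each entry is the product of all later sizes.
    strides, acc = [], 1
    for s in reversed(shape):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))

print(contiguous_stride((2, 3)))  # (3, 1) -- matches a.stride()
print(contiguous_stride((3, 2)))  # (2, 1) -- matches b.stride()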

Understanding the Tensor view

Reference link

[Screenshot of the official torch.Tensor.view documentation]

Roughly, it means:

The returned tensor shares the same data and must have the same number of elements as the original, but may have a different shape. For a tensor to be viewed, the new view size must be compatible with its original size and stride; that is, each new view dimension must either be a subspace of an original dimension, or span only across original dimensions d, d+1, …, d+k that satisfy the following contiguity condition:
stride[i] = stride[i+1] × size[i+1]    for all i = d, …, d+k−1

Otherwise, you must first call contiguous() to obtain a tensor that satisfies the contiguity condition and then apply view(), or use reshape() directly to perform the shape transformation; in that case, however, the resulting tensor does not share memory with the original one, since new storage is allocated for it.
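As a sanity check, the condition can be verified by hand. Here is a minimal sketch (satisfies_contiguity is a hypothetical helper; in practice you would simply call Tensor.is_contiguous(), which also handles edge cases such as size-1 dimensions):

import torch

def satisfies_contiguity(t):
    # stride[i] == stride[i+1] * size[i+1] for adjacent dimensions,
    # with the innermost stride equal to 1
    if t.dim() == 0:
        return True
    if t.stride(-1) != 1:
        return False
    return all(t.stride(i) == t.stride(i + 1) * t.size(i + 1)
               for i in range(t.dim() - 1))

a = torch.arange(9).reshape(3, 3)
b = a.permute(1, 0)
print(satisfies_contiguity(a), a.is_contiguous())  # True True
print(satisfies_contiguity(b), b.is_contiguous())  # False False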

So how do we tell whether a tensor satisfies the contiguity condition? Let's work through a series of examples.

Check a tensor's stride and size attributes as follows:
[Figure: code inspecting the size() and stride() of a contiguous tensor]
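The figure itself is unavailable, but based on the surrounding text it most likely showed something like the following (a reconstruction under that assumption):

import torch
a = torch.arange(9).reshape(3, 3)  # a freshly created, contiguous tensor
print('size   of a:', a.size())    # torch.Size([3, 3])
print('stride of a:', a.stride())  # (3, 1)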

We can see that the result is contiguous: stride[0] = 3 = 1 × 3 = stride[1] × size[1], so the condition above holds.

Now let's look at an example where contiguity is not satisfied:

import torch
a = torch.arange(9).reshape(3, 3)  # initialize tensor a
b = a.permute(1, 0)  # transpose a
print('struct of b:\n', b)
print('size   of b:', b.size())    # check b's shape
print('stride of b:', b.stride())  # check b's stride
 
'''   Output   '''
struct of b:
tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])
size   of b: torch.Size([3, 3])
stride of b: (1, 3)   # note: the contiguity condition is not satisfied here
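Since b fails the contiguity condition, calling view() on it raises an error, while reshape() still works (a small sketch to confirm):

# b.view(9) would raise:
#   RuntimeError: view size is not compatible with input tensor's size and stride ...
print(b.reshape(9))            # tensor([0, 3, 6, 1, 4, 7, 2, 5, 8]) -- a copy
print(b.contiguous().view(9))  # same result, via an explicit contiguous copy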

Output the storage areas of a and b to see if there is any difference:

import torch
a = torch.arange(9).reshape(3, 3)             # initialize tensor a
print('ptr of storage of a: ', a.storage().data_ptr())  # address of a's storage
print('storage of a: \n', a.storage())        # how the data is laid out in a's storage
b = a.permute(1, 0)                           # transpose
print('ptr of storage of b: ', b.storage().data_ptr())  # address of b's storage
print('storage of b: \n', b.storage())        # how the data is laid out in b's storage
 
'''   Output   '''
ptr of storage of a:  1899603060672
storage of a: 
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.LongStorage of size 9]
ptr of storage of b:  1899603060672
storage of b: 
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.LongStorage of size 9]

From the results we can see that a and b still share the same storage, and the order of the data in storage has not changed: b merely changes how the data is indexed. So why does b fail the contiguity condition? The reason is simple: by the formula above we would need stride[0] = stride[1] × size[1] = 3 × 3 = 9, but b.stride()[0] is 1, so the condition does not hold. Figure 3 illustrates this:
[Figure 3: why the transposed tensor b fails the contiguity condition]

torch.reshape

[Screenshot of the official torch.reshape documentation]

Function: like view(), reshape() converts the input tensor to a new shape.
But reshape() is more forgiving: you can think of a.reshape(...) as a.view(...) with an automatic a.contiguous().view(...) fallback.
That is, when the tensor satisfies the contiguity condition, a.reshape(...) returns the same result as a.view(...); otherwise it returns the same result as a.contiguous().view(...), which copies the data.
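The storage-sharing behaviour described above can be checked directly (a minimal sketch):

import torch

a = torch.arange(9).reshape(3, 3)
r1 = a.reshape(9)    # a is contiguous, so r1 is a view of a
print(r1.storage().data_ptr() == a.storage().data_ptr())  # True: shared storage

b = a.permute(1, 0)  # non-contiguous
r2 = b.reshape(9)    # contiguity fails, so r2 is a copy
print(r2.storage().data_ptr() == b.storage().data_ptr())  # False: new storage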

Reference Articles and Links

https://blog.csdn.net/Flag_ing/article/details/109129752
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch

Origin: blog.csdn.net/BXD1314/article/details/126562501