pytorch view()和reshape() 详解

前言

如果没有时间看下去,这里直接告诉你结论:

  • 两者都是用来重塑tensor的shape的。

  • view只适合对满足连续性条件(contiguous)的tensor进行操作,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。

  • reshape对适合对满足连续性条件(contiguous)的tensor进行操作返回值是视图,否则返回副本(此时等价于先调用contiguous()方法在使用view())

  • 考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view

  • view能干的reshape都能干 如果只是重塑一个tensor的shape 那就无脑选择reshape

pytorch Tensor 介绍

想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的头信息区(Tensor)和存储区 (Storage)以及tensor的步长Stride

Tensor 文档链接

Tensor 存储结构介绍

tensor数据采用头信息区(Tensor)和存储区 (Storage)分开存储的形式,如图1所示。变量名以及其存储的数据是分为两个区域分别存储的。比如,我们定义并初始化一个tensor,tensor名为A,A的形状size、步长stride、数据的索引等信息都存储在头信息区,而A所存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或修改等操作后赋值给B,则B的数据共享A的存储区,存储区的数据数量没变,变化的只是B的头信息区对数据的索引方式。如果听说过浅拷贝和深拷贝的话,很容易明白这种方式其实就是浅拷贝。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N2OUa60I-1661605448596)(https://note.youdao.com/yws/res/4/WEBRESOURCEc24383dc6161107a6217f220c1813a44)]

代码示例如下:

import torch
a = torch.arange(5)  # 初始化张量 a 为 [0, 1, 2, 3, 4]
b = a[2:]            # 截取张量a的部分值并赋值给b,b其实只是改变了a对数据的索引方式
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr())  # 打印a的存储区地址
print('ptr of storage of b:', b.storage().data_ptr())  # 打印b的存储区地址,可以发现两者是共用存储区
 
print('==================================================================')
 
b[1] = 0    # 修改b中索引为1,即a中索引为3的数据为0
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr())  # 打印a的存储区地址,可以发现a的相应位置的值也跟着改变,说明两者是共用存储区
print('ptr of storage of b:', b.storage().data_ptr())  # 打印b的存储区地址
 

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d5robjjk-1661605448598)(https://note.youdao.com/yws/res/c/WEBRESOURCEd715646616ddbd9ab2e857e3934bea1c)]

Tensor的步长(stride)属性

torch的tensor也是有步长属性的,说起stride属性是不是很耳熟?是的,卷积神经网络中卷积核对特征图的卷积操作也是有stride属性的,但这两个stride可完全不是一个意思哦。tensor的步长可以理解为从索引中的一个维度跨到下一个维度中间的跨度

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2RVZKAQT-1661605448599)(https://note.youdao.com/yws/res/7/WEBRESOURCE9995364b2e5d8cf3a377a8952e9b68b7)]

我们看下如下例子:


import torch
a = torch.arange(6).reshape(2, 3)  # 初始化张量 a
b = torch.arange(6).view(3, 2)     # 初始化张量 b
print('a:', a)
print('stride of a:', a.stride())  # 打印a的stride
print('b:', b)
print('stride of b:', b.stride())  # 打印b的stride

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Z9CgkX3S-1661605448600)(https://note.youdao.com/yws/res/8/WEBRESOURCEa4349796e7e709732c62fb42c6635888)]

Tensor View 理解

参考链接

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yJJRFyYh-1661605448601)(https://note.youdao.com/yws/res/d/WEBRESOURCE71d76e815ae8727f47ab028bfa1263ad)]

大致意思是:

返回的张量共享相同的数据,并且必须具有相同数量的元素,但可能具有不同的大小。对于要查看的张量,新视图大小必须与其原始大小和步幅兼容,即每个新视图维度必须是原始维度的子空间,或者满足以下连续条件:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g3br8pUJ-1661605448602)(https://note.youdao.com/yws/res/8/WEBRESOURCEece0d25023a8856d1343a7c255a9c8c8)]

否则需要先使用contiguous()方法将原始tensor转换为满足连续条件的tensor,然后就可以使用view方法进行shape变换了。或者直接使用reshape方法进行维度变换,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

如何理解tensor是否满足连续条件呐?下面通过一系列例子来慢慢理解下

查看tensor的stride、size属性
如下例子:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4ylVLZbx-1661605448603)(https://note.youdao.com/yws/res/4/WEBRESOURCEf56c32ec5f8974b6a2656f7a506cb534)]

我们可以看到结果是满足连续性的 stride[0] = 3 = 1X3

下面我们看看不满足连续性的例子:

import torch
a = torch.arange(9).reshape(3, 3)     # 初始化张量a
b = a.permute(1, 0)  # 对a进行转置
print('struct of b:\n', b)
print('size   of b:', b.size())    # 查看b的shape
print('stride of b:', b.stride())  # 查看b的stride
 
'''   运行结果   '''
struct of b:
tensor([[0, 3, 6],
        [1, 4, 7],
        [2, 5, 8]])
size   of b: torch.Size([3, 3])
stride of b: (1, 3)   # 注:此时不满足连续性条件

输出a和b的存储区来看一下有没有什么不同:

import torch
a = torch.arange(9).reshape(3, 3)             # 初始化张量a
print('ptr of storage of a: ', a.storage().data_ptr())  # 查看a的storage区的地址
print('storage of a: \n', a.storage())        # 查看a的storage区的数据存放形式
b = a.permute(1, 0)                           # 转置
print('ptr of storage of b: ', b.storage().data_ptr())  # 查看b的storage区的地址
print('storage of b: \n', b.storage())        # 查看b的storage区的数据存放形式
 
'''   运行结果   '''
ptr of storage of a:  1899603060672
storage of a: 
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.LongStorage of size 9]
ptr of storage of b:  1899603060672
storage of b: 
  0
 1
 2
 3
 4
 5
 6
 7
 8
[torch.LongStorage of size 9]

由结果可以看出,张量a、b仍然共用存储区,并且存储区数据存放的顺序没有变化,这也充分说明了b与a共用存储区,b只是改变了数据的索引方式。那么为什么b就不符合连续性条件了呐(T-T)?其实原因很简单,我们结合图3来解释下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6SrlS5US-1661605448603)(https://note.youdao.com/yws/res/2/WEBRESOURCE1ae714a0b792817b02987eb595d77e02)]

Torch.reshape

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SG9BbXZu-1661605448604)(https://note.youdao.com/yws/res/c/WEBRESOURCE3162ae9c77fe9339b2b6cd326a21e5cc)]

作用:与view方法类似,将输入tensor转换为新的shape格式。
但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()。
即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同

参考文章和链接

https://blog.csdn.net/Flag_ing/article/details/109129752
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch

猜你喜欢

转载自blog.csdn.net/BXD1314/article/details/126562501