分析pytorch中的view和reshape

在pytorch编写的代码中,view方法和reshape方法是非常常用的。但是经过我的一次亲手撸代码的经历,发现事情并不那么简单,那就一起来看一下我为何这么说吧?

一、基础认知

在此之前,希望对以下几点我们是有共同认知的:

  1. (B, C, H, W)这个形式的tensor通常都是作为网络的输入,而每个维度都是有其意义的,比如第一维的B就是不同的图片(暂且认为是输入的图片),也就是一个批次中的batch size大小,C是tensor的通道, H和W分别是tensor 的spatial size,即高和宽。比如下面这幅图就可以看作(2,3,2,2
    示例
  2. 还有一点需要认知的就是(2×2×3, 2)与(3×2×2, 2)是不同的。有人会说那不都是(12, 2)嘛,这样说的人需要回到第一点再思考,也可借助下面这幅图理解【所以看一些tensor的shape的时候尽量要拆开来做注释】:
    (2×2×3, 2)与(3×2×2, 2)的区别
    到这里看不懂还可以看下面的例子。

二、官方文档

现在就可以来看看官方文档中怎么解释这两个用法的:

2.1 view()

view(*shape) → Tensor
view()的pytorch官方文档

2.2 reshape()

torch.reshape(input, shape) → Tensor
reshape()的pytorch官方文档

三、结论

  • view函数和reshape函数在我们使用过程中的基本功能是类似的,reshape可在不连续的情况下不会报错
  • 对于-1的用法,两者都是:如果输入(B, C, H, W)经过view(B, -1)或者reshape(B, -1)都会变成(B, W×H×C),而不是(B, C×H×W),这是需要特别注意的。【即反着来的】

猜你喜欢

转载自blog.csdn.net/laizi_laizi/article/details/108682858