pytorch 中 contiguous() 函数理解

pytorch 中 contiguous() 函数理解




文章抄自 Pytorch中contiguous()函数理解-清晨的光明-CSDN ,仅用作个人学习和记录。如有帮助,请关注原作者,给原作者点赞。


引言

在 pytorch 中,只有很少几个操作是不改变 tensor 的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据

会改变元数据的操作是:

  • narrow()
  • view()
  • expand()
  • transpose()

在使用 transpose() 进行转置操作时,pytorch 并不会创建新的、转置后的 tensor,而是修改了 tensor 中的一些属性(也就是元数据),使得此时的 offset 和 stride 是与转置 tensor 相对应的。转置的 tensor 和原 tensor 的内存是共享的!

transpose() 后改变元数据,代码示例:

x = torch.randn(3, 2)
y = torch.transpose(x, 0, 1)
print("修改前:")
print("x-", x)
print("y-", y)

print("\n修改后:")
y[0, 0] = 11
print("x-", x)
print("y-", y)

运行结果:

修改前:
x- tensor([[-0.5670, -1.0277],
           [ 0.1981, -1.2250],
           [ 0.8494, -1.4234]])
y- tensor([[-0.5670,  0.1981,  0.8494],
           [-1.0277, -1.2250, -1.4234]])
 
修改后:
x- tensor([[11.0000, -1.0277],
           [ 0.1981, -1.2250],
           [ 0.8494, -1.4234]])
y- tensor([[11.0000,  0.1981,  0.8494],
           [-1.0277, -1.2250, -1.4234]])

可以看到,改变了 y 的元素值的同时,x 的元素的值也发生了变化

因此可以说,x 是 contiguous 的,但 y 不是(因为内部数据不是通常的布局方式)。注意不要被 contiguous 的字面意思“连续的”误解,tensor 中数据还是在内存中一块区域里,只是布局的问题!

为什么这么说:因为,y 里面数据布局的方式和从头开始创建一个常规的 tensor 布局的方式是不一样的。这个可能只是 python 中之前常用的浅拷贝,y 还是指向 x 变量所处的位置,只是说记录了 transpose 这个变化的布局


使用 contiguous()

如果想要断开这两个变量之间的依赖(x 本身是 contiguous 的),就要使用 contiguous() 针对 x 进行变化,感觉上就是我们认为的深拷贝
调用 contiguous() 时,会强制拷贝一份 tensor,让它的布局和从头创建的一模一样,但是两个 tensor 完全没有联系

代码示例:

x = torch.randn(3, 2)
y = torch.transpose(x, 0, 1).contiguous()
print("修改前:")
print("x-", x)
print("y-", y)

print("\n修改后:")
y[0][0] = 11
print("x-", x)
print("y-", y)

运行结果:

修改前:
x- tensor([[ 0.9730,  0.8559],
           [ 1.6064,  1.4375],
           [-1.0905,  1.0690]])
y- tensor([[ 0.9730,  1.6064, -1.0905],
           [ 0.8559,  1.4375,  1.0690]])
 
修改后:
x- tensor([[ 0.9730,  0.8559],
           [ 1.6064,  1.4375],
           [-1.0905,  1.0690]])
y- tensor([[11.0000,  1.6064, -1.0905],
           [ 0.8559,  1.4375,  1.0690]])

可以看到,当对 y 使用了 .contiguous() 后,改变 y 的值时,x 没有任何影响!

后记

一般来说这一点不用太担心,当遇到需要调用 contiguous() 的地方,运行时会提示你:

RuntimeError: input is not contiguous

这个时候只需要在该变量后面加上 .contiguous() 就可以了!

猜你喜欢

转载自blog.csdn.net/weixin_51524504/article/details/129145009