pytorch | transpose、permute、view、contiguous、is_contiguous、reshape

  • transpose、contiguous、view
a = torch.randn(2,3) #随机产生的2*3的tensor,内存是连续的,所以打印出“真”
if a.is_contiguous():
    print("真")
else:
    print("假")
a = a.transpose(0,1)#经过transpose,维度变换后,内存不连续了,所以打印出“假”
if a.is_contiguous():
    print("真")
else:
    print("假")
a = a.contiguous()#contiguous函数将不连续内存变为连续内存,所以打印出“真”
if a.is_contiguous():
    print("真")
else:
    print("假")
a = a.view(1,6)#变换维度,但是内存依然连续,所以打印出“真“,view遇到不连续的会报错,只有连续的才不会报错
if a.is_contiguous():
    print("真")
else:
    print("假")

结果:

 

其中:is_contiguous函数是判断一个变量是否内存连续,连续返回True,不连续返回False


  • torch.reshape()
a = torch.randn(2,3) #随机产生的2*3的tensor,内存是连续的,所以打印出“真”
if a.is_contiguous():
    print("真")
else:
    print("假")
a = a.transpose(0,1)#经过transpose,维度变换后,内存不连续了,所以打印出“假”
if a.is_contiguous():
    print("真")
else:
    print("假")
a = torch.reshape(a,(1,6))#reshape相当于将contiguous和view进行了合并,无论之前是连续还是不连续,最终都是连续的,且不会报错
if a.is_contiguous():
    print("真")
else:
    print("假")

结果: 


  • 查看变量的内存地址:id()函数
a = torch.randn(2,3)
print(id(a))

 结果:


transpose、permute异同点

transpose:交换维度 

torch.manual_seed(1)
a = torch.randn(2, 3)
print(a)
b = a.transpose(0, 1)
print(b)
print('=====================')
a1 = torch.randn(2, 3, 4)
print(a1)
b1 = a1.transpose(1, 2)
print(b1)

结果:

permute:排列、置换维度,比transpose更灵活,适用于多维度

a1 = torch.from_numpy(np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]))
print(a1)
b1 = a1.permute(1, 0, 2)  # 原本是2*3*3,将第1维和第2维交换,结果变为3*2*3
print(b1)
c1 = a1.permute(1, 2, 0)  # 原本是2*3*3,第1维变为第3维,第2维变为第1维,第3维变为第2维,结果是:3*3*2
print(c1)

结果:

虽然都是维度变化,但transpose只能选择两个维度进行交换,permute则可以多维交换。

看函数原型也能看出:

def transpose(self, dim0: _int, dim1: _int) -> Tensor: ...
def permute(self, dims: _size) -> Tensor: ...

猜你喜欢

转载自blog.csdn.net/songxiaolingbaobao/article/details/112259559