Pytorch中permute(),transpose(),view()函数

一.  permute()函数

  作用:

  • permute函数的作用是对tensor的维度进行转置。 

若随机生成一个1X2X3X4的四维向量,permute函数的参数表示的是转置后的向量位置。比如原向量中(1, 2, 3, 4),1的下标是0,2的下标是1,3的下标是2,4的下标是3;在x.permute(2, 1, 0, 3)中,2代表原来下表为2的数字3放在第一位(也就是N),1代表原来下表为1的数字2放在第二位(也就是C),0代表原来下表为0的数字1放在第三位(也就是H),3代表原来下表为3的数字4放在第四位(也就是W),代码如下:

import torch
import torch.nn as nn

x = torch.randn(1, 2, 3, 4)
print(x.size())      
print(x.permute(2, 1, 0, 3).size())

========================================
torch.Size([1, 2, 3, 4])   #原来的tensor
torch.Size([3, 2, 1, 4])   #转置后的tensor

二.  transpose()函数

  作用:

  • 函数的作用是也是对tensor进行转置。

但是torch.transpose()只能操作二维转置。**这个意思不是torch.transpose()只能作用于二维向量,它的意思是一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:

x.transpose(0,2).transpose(1,3)
====================================
torch.Size([3, 4, 1, 2])   #转置后的tensor

二.  view()函数

  作用:

  • view()相当于reshape、resize,重新调整Tensor的形状。

1.当一个Tensor经过 tensor.transpose()、tensor.permute()等这类维度变换函数后,内存并不是连续的,而tensor.view()维度变形函数的要求是需要Tensor的内存连续,所以在运行tensor.view()之前,先使用 tensor.contiguous(),防止报错。

2.维度变换函数是进行的浅拷贝操作(只复制了指像某个对象的指针,新旧对象还是共享同一块内存)即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;

import torch
import torch.nn as nn
import numpy as np

y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())

print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 
==================================================
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([[[1, 4]],

        [[2, 5]],

        [[3, 6]]])
tensor([[[1, 2],
         [3, 4],
         [5, 6]]])

特殊用法:

view()中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

猜你喜欢

转载自blog.csdn.net/m0_62278731/article/details/131934376