pytorch中与维度/变换相关的几个函数

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013700358/article/details/86301106
  1. torch.size ()
    先说torch.size()函数,因为后面的方法都会用这个方法看到变换后的矩阵的维度
    通过该方法,可以查看当前Tensor的维度,用法也很简单:
>>>import torch
>>>a = torch.Tensor([[[1, 2, 3], [4, 5, 6]]])
>>>a.size()
torch.Size([1, 2, 3])
  1. torch.view()
    官方文档中的解释:
    torch.view简单说,把原本的tensor尺寸,转变为你想要的尺寸,例如原尺寸为23,现在可以转为32或16等,但一定要保证等式成立,不能目标尺寸为33
    此外,也可以设其中一个尺寸为-1,表示机器内部自己计算,但同时只能有一个为-1,用法如下:
>>> b=a.view(-1, 3, 2)
>>> b
tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]]])
>>> b.size()
torch.Size([1, 3, 2])
  1. torch.squeeze() / torch.unsqueeze()
    torch.squeeze(n)函数表示压缩tensor中第n维为1的维数,比如下面第一个,b.squeeze(2).size(),原始的b为上面的torch.Size([1, 3, 2]),第二维是2≠1,所以不压缩,尺寸保持不变;而若b.squeeze(0).size(),则发现第一维为1,因此压缩为3x2的tensor
>>> b.squeeze(2).size()
torch.Size([1, 3, 2])
>>> b.squeeze(0).size()
torch.Size([3, 2])

相反的,torch.unsqueeze(n)则是在第n维增加一个维数=1,如下,表示在原始的b的第二维增加一维,则尺寸变为1 * 3 * 1 * 2

>>> b.unsqueeze(2).size()
torch.Size([1, 3, 1, 2])
>>> b.unsqueeze(2)
tensor([[[[1., 2.]],

         [[3., 4.]],

         [[5., 6.]]]])
  1. torch.permute()
    这个函数表示,将原始的tensor,按照自己期望的位置重新排序,例如原始tensor的第0、1、2维分别是1、3、2,那么当我执行permute(2, 0, 1),则将第三维放在最前,第一维放在中间,第二维放在最后,也就变成了2 * 1 * 3,注意这里表示的维数的index,而不是具体几维:
>>> b.permute(2, 0, 1).size()
torch.Size([2, 1, 3])
>>> b.permute(2, 0, 1)
tensor([[[1., 3., 5.]],

        [[2., 4., 6.]]])

暂时只想到这些,若有错误还请指正,或有其他相关函数,我也将持续更新。

猜你喜欢

转载自blog.csdn.net/u013700358/article/details/86301106