pytorch squeeze()和unsqueeze() 详解

squeeze在的中文意思压缩,unsqueeze取消压缩,unsqueeze是添加维度的意思,它的具体用法如下面代码

无论压缩还是取消压缩,都要有维度的,就是从那个方向压缩,x,y,z方向??

unsqueeze

unsqueeze(-1) ,如果是二维矩阵 它等价于unsqueeze(2)

>>> import torch
>>> a1 = torch.arange(0,12).view(3,4)
>>> print(a1.shape)
torch.Size([3, 4])
>>> print(a1)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> a2 =a1.unsqueeze(-1)
>>> print(a2.shape)
torch.Size([3, 4, 1])
>>> print(a2)
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6],
         [ 7]],

        [[ 8],
         [ 9],
         [10],
         [11]]])
>>>

unsqueeze(2)

>>> a3 =a1.unsqueeze(2)
>>> print(a3.shape)
torch.Size([3, 4, 1]) # 在增加一个维度  在第三位置插入维度
>>> print(a3)
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6],
         [ 7]],

        [[ 8],
         [ 9],
         [10],
         [11]]])
>>>

unsqueeze(2) 与 unsqueeze(-1)相同

unsqueeze(0)

>>> a4 =a1.unsqueeze(0)
>>> print(a4.shape)
torch.Size([1, 3, 4])  # 插入维度 在第一位置
>>> print(a4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])
>>>

unsqueeze(1)

>>> a5 =a1.unsqueeze(1)
>>> print(a5.shape)
torch.Size([3, 1, 4]) # 插入维度 在第二位置
>>> print(a5)
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])
>>>

简单的一个二维向量  经过unsqueeze()能够得到三种可能的向量。

squeeze :

在a2 基础上进行压缩

>>> aa2 =a2.squeeze()
>>> print(aa2.shape)
torch.Size([3, 4])
>>> print(aa2)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>  与a1一样了

去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用;

上面的squeeze功能,将torch.Size([3, 4, 1]) ---》torch.Size([3, 4])

在a4的基础上 压缩

>>> aa4 =a4.squeeze()
>>> print(a4.shape)
torch.Size([1, 3, 4])
>>> print(aa4.shape)
torch.Size([3, 4])
>>> print(aa4)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>

猜你喜欢

转载自blog.csdn.net/Vertira/article/details/130644029