pytorch --- tensor.squeeze(dim)和unsqueeze(dim)

tensor.squeeze(dim)

作用: 如果dim指定的维度的值为1,则将该维度删除,若指定的维度值不为1,则返回原来的tensor

例子:

x = torch.rand(2,1,3)
print(x)
print(x.squeeze(1))
print(x.squeeze(2))

输出:

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])
        
tensor([[0.7031, 0.7173, 0.0606],
        [0.6884, 0.4072, 0.0516]])

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])

如上结果所示:x.shape=[2, 1, 3] , 第一维度的值为1, 因此x.squeeze(dim=1)的输出会将第一维度去掉,其输出shape=[2,3], 第二维度值不为1, 因此x.squeeze(dim=2)输出tensor的shape不变

tensor.unsqueeze(dim)

这个函数主要是对数据维度进行扩充。给指定位置加上维数为1的维度,比如原本有个三行的数据(3,),在0的位置加了一维就变成一行三列(1,3)。还有一种形式就是b=torch.squeeze(tensor,dim) 就是在tensor中指定位置 dim 加上一个维数为1的维度

例子:

x = torch.rand(2,3)
print(x)
print("x.shape:", x.shape)
y = torch.unsqueeze(x, 1)
print(y)
print("y.shape:", y.shape)
z = x.unsqueeze(2)
print(z)
print("z.shape:", z.shape)

输出:

tensor([[0.1255, 0.7249, 0.5253],
        [0.9247, 0.4592, 0.3944]])
x.shape: torch.Size([2, 3])


tensor([[[0.1255, 0.7249, 0.5253]],

        [[0.9247, 0.4592, 0.3944]]])
y.shape: torch.Size([2, 1, 3])


tensor([[[0.1255],
         [0.7249],
         [0.5253]],

        [[0.9247],
         [0.4592],
         [0.3944]]])
z.shape: torch.Size([2, 3, 1])
[Finished in 2.6s]
发布了33 篇原创文章 · 获赞 1 · 访问量 2611

猜你喜欢

转载自blog.csdn.net/orangerfun/article/details/104012564
今日推荐