【pytorch】squeeze()和unsqueeze()函数介绍

在pytorch中,我们对张量Tensor的维度进行压缩或者扩充(被压缩或者扩充的维度为1),经常使用的是squeeze()函数和unsqueeze()函数

1. torch.squeeze(input, dim=None)

 

用于降维。将 input 中维度为1的部分去除,当维度大于等于2时,squeeze()无作用。

也可通过 input.squeeze( dim=None, out=None)调用。

  • input(Tensor):输入张量,即被操作目标
  • dim(int, optional):在指定维去掉一个维度。若不指定则自动寻找,指定则当指定的维度为1时去掉,不为1时则不改变

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

 示例

# 示例1
a = torch.Tensor(1,3)
>>
tensor([[-1.37,4.56,-3.57]])
 
print a.squeeze(0) #第一个维度大小是1,所以去除
>>
tensor([-1.37,4.56,-3.57])
 
print a.squeeze(1) ##第二个维度大小是3,所以不去除
>>
tensor([[-1.37,4.56,-3.57]])
 
# 示例2
c = torch.Tensor(3,1)
print c
>>
tensor([[-3.54],
[3.09],
[0.00]])
 
print c.squeeze(0)##第一个维度大小不是1,所以不去除
>>
tensor([[-3.54],
[3.09],
[0.00]])
 
print c.squeeze(1)#第二个维度大小是1,所以去除
>>
tensor([-3.54,3.09,0.00])


# 示例3
x = torch.zeros(2, 1, 2, 1, 2)
x.size()
>>
torch.Size([2, 1, 2, 1, 2])

y = torch.squeeze(x)
y.size()
>>
torch.Size([2, 2, 2])

y = torch.squeeze(x, 0)
y.size()
>>
torch.Size([2, 1, 2, 1, 2])

2.  torch.unsqueeze(input, dim)

为pytorch中的tensor增加一个维度。

 也可通过 input.unsqueeze( dim=None, out=None)调用。

  • input(Tensor):输入张量,即被操作目标
  • dim(int, optional):在哪一维增加一个维度,dim必须被指定

示例

import torch
a = torch.arange(12).reshape([3,4])
print(a)
b = a.unsqueeze(1)
print(b)
>>
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]]])

参考官方文档:

torch.squeeze — PyTorch 2.0 documentation

torch.unsqueeze — PyTorch 2.0 documentation

扫描二维码关注公众号,回复: 16600791 查看本文章

猜你喜欢

转载自blog.csdn.net/m0_70813473/article/details/131184418