pytorch下的unsqueeze和squeeze用法

版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/81875508

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
 

import torch

a=torch.rand(2,3,1)             
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1])

print(a.size())                 #torch.Size([2, 3, 1])
print(a.squeeze().size())       #torch.Size([2, 3])

print(a.squeeze(0).size())      #torch.Size([2, 3, 1])

print(a.squeeze(-1).size())     #torch.Size([2, 3])
print(a.size())                 #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())     #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())     #torch.Size([2, 3, 1])
print(a.squeeze(1).size())      #torch.Size([2, 3, 1])
print(a.squeeze(2).size())      #torch.Size([2, 3])
print(a.squeeze(3).size())      #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

print(a.unsqueeze().size())     #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())   #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())   #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())   #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())    #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())    #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())    #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())    #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())       #torch.Size([2, 3])

猜你喜欢

转载自blog.csdn.net/york1996/article/details/81875508
今日推荐