squeeze()
函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
其中squeeze(0)代表若第一维度值为1则去除第一维度,squeeze(1)代表若第二维度值为1则去除第二维度。
eg1:
a = torch.Tensor(1,3)
print a
>>>
tensor([[-1.37,4.56,-3.57]])
print a.squeeze(0)
>>>
tensor([-1.37,4.56,-3.57])
print a.squeeze(1)
>>>
tensor([[-1.37,4.56,-3.57]])
eg2:
b = torch.Tensor(2,3)
print b
>>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
print b.squeeze(0)
>>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
print b.squeeze(1)
>>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
eg3:
c = torch.Tensor(3,1)
print c
>>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(0)
>>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(1)
>>>
tensor([-3.54,3.09,0.00])
eg4:
d = torch.rand(4,1,3)
print d
>>>
tensor([[[0.19,0.25,0.23]],
[[0.91,0.66,0.12]],
[[0.82,0.07,0.73]],
[[0.35,0.06,0.10]]])
print d.squeeze()
>>>
tensor([[0.19,0.25,0.23],
[0.91,0.66,0.12],
[0.82,0.07,0.73],
[0.35,0.06,0.10]])
eg5:
e = torch.rand(4,3,1)
print e
>>>
tensor([[[0.97],
[0.86],
[0.52]],
[[0.88],
[0.76],
[0.54]],
[[0.61],
[0.27],
[0.56]],
[[0.56],
[0.66],
[0.53]]])
print e.squeeze()
>>>
tensor([[0.97, 0.86, 0.52],
[0.88, 0.76, 0.54],
[0.61, 0.27, 0.56],
[0.56, 0.66, 0.53]])
eg6:
f = torch.rand(4,3,2)
print f
>>>
tensor([[[0.90,0.26],
[0.78,0.33],
[0.45,0.71]],
[[0.25,0.87],
[0.36,0.37],
[0.60,0.88]],
[[0.32,0.06],
[0.63,0.23],
[0.13,0.08]],
[[0.53,0.92],
[0.56,0.27],
[0.41,0.08]]])
print f.squeeze()
>>>
tensor([[[0.90,0.26],
[0.78,0.33],
[0.45,0.71]],
[[0.25,0.87],
[0.36,0.37],
[0.60,0.88]],
[[0.32,0.06],
[0.63,0.23],
[0.13,0.08]],
[[0.53,0.92],
[0.56,0.27],
[0.41,0.08]]])
unsqueeze()
函数功能:与squeeze()函数功能相反,用于添加维度。
eg:
g = torch.Tensor(3)
print g
>>>
tensor([3.27,4.56,-4.84])
print g.unsqueeze(0)
>>>
tensor([[3.27,4.56,-4.84]])
print g.unsqueeze(1)
>>>
tensor([[3.27],
[4.56],
[-4.84]])
#print g.unsqueeze() 必须指明维度
参考文章链接: