torch.squeeze()和torch.unsqueeze()

原文:https://www.jianshu.com/p/e0ddfcf2e72d

1. torch.squeeze(tensor)

和numpy等库函数中的squeeze()函数作用一样,torch.squeeze()函数的作用是压缩一个tensor的维数为1的维度,使该tensor降维变成最紧凑的形式:

 1 In [1]: import numpy as np                                                      
 2 
 3 In [2]: import torch                                                            
 4 
 5 In [3]: a = torch.arange(9).view(3,1,3)                                         
 6 
 7 In [4]: a                                                                       
 8 Out[4]: 
 9 tensor([[[0, 1, 2]],
10 
11         [[3, 4, 5]],
12 
13         [[6, 7, 8]]])
14 
15 In [5]: a.size()                                                                
16 Out[5]: torch.Size([3, 1, 3])
17 
18 In [6]: a.dim()                                                                 
19 Out[6]: 3
20 
21 In [7]: b = torch.squeeze(a)                                                    
22 
23 In [8]: b                                                                       
24 Out[8]: 
25 tensor([[0, 1, 2],
26         [3, 4, 5],
27         [6, 7, 8]])
28 
29 In [9]: b.size()                                                                
30 Out[9]: torch.Size([3, 3])
31 
32 In [10]: b.dim()                                                                
33 Out[10]: 2

同样numpy中功能一样:

 1 In [11]: c = np.arange(9).reshape(1,3,1,3)                                      
 2 
 3 In [12]: c                                                                      
 4 Out[12]: 
 5 array([[[[0, 1, 2]],
 6 
 7         [[3, 4, 5]],
 8 
 9         [[6, 7, 8]]]])
10 
11 In [13]: c.shape, c.ndim                                                        
12 Out[13]: ((1, 3, 1, 3), 4)
13 
14 In [14]: d = np.squeeze(c)                                                      
15 
16 In [15]: d                                                                      
17 Out[15]: 
18 array([[0, 1, 2],
19        [3, 4, 5],
20        [6, 7, 8]])
21 
22 In [16]: d.shape, d.ndim                                                        
23 Out[16]: ((3, 3), 2)
2. torch.unsqueeze(tensor, dim)

unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。

 1 In [17]: b                                                                      
 2 Out[17]: 
 3 tensor([[0, 1, 2],
 4         [3, 4, 5],
 5         [6, 7, 8]])
 6 
 7 In [18]: b.size(), b.dim()                                                      
 8 Out[18]: (torch.Size([3, 3]), 2)
 9 
10 In [20]: b_un = torch.unsqueeze(b, 0)                                           
11 
12 In [21]: b_un                                                                   
13 Out[21]: 
14 tensor([[[0, 1, 2],
15          [3, 4, 5],
16          [6, 7, 8]]])
17 
18 In [22]: b_un.size(), b_un.dim()                                                
19 Out[22]: (torch.Size([1, 3, 3]), 3)
20 
21 In [23]: b_un_un = torch.unsqueeze(b_un, 3)                                     
22 
23 In [24]: b_un_un                                                                
24 Out[24]: 
25 tensor([[[[0],
26           [1],
27           [2]],
28 
29          [[3],
30           [4],
31           [5]],
32 
33          [[6],
34           [7],
35           [8]]]])
36 
37 In [25]: b_un_un.size(), b_un_un.dim()                                          
38 Out[25]: (torch.Size([1, 3, 3, 1]), 4)

猜你喜欢

转载自www.cnblogs.com/KAKAFEIcoffee/p/11963989.html