利用 PyTorch 复现空间金字塔池化(Spatial Pyramid Pooling, SPP)层

SPPNet 的原理这里不再赘述,直接上代码。

 1 from math import floor, ceil
 2 import torch
 3 import torch.nn as nn
 4 import torch.nn.functional as F
 5 
 6 class SpatialPyramidPooling2d(nn.Module):
 7     r"""apply spatial pyramid pooling over a 4d input(a mini-batch of 2d inputs 
 8     with additional channel dimension) as described in the paper
 9     'Spatial Pyramid Pooling in deep convolutional Networks for visual recognition'
10     Args:
11         num_level:
12         pool_type: max_pool, avg_pool, Default:max_pool
13     By the way, the target output size is num_grid:
14         num_grid = 0
15         for i in range num_level:
16             num_grid += (i + 1) * (i + 1)
17         num_grid = num_grid * channels # channels is the channel dimension of input data
18     examples:
19         >>> input = torch.randn((1,3,32,32), dtype=torch.float32)
20         >>> net = torch.nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1),\
21                                       nn.ReLU(),\
22                                       SpatialPyramidPooling2d(num_level=2,pool_type='avg_pool'),\
23                                       nn.Linear(32 * (1*1 + 2*2), 10))
24         >>> output = net(input)
25     """
26     
27     def __init__(self, num_level, pool_type='max_pool'):
28         super(SpatialPyramidPooling2d, self).__init__()
29         self.num_level = num_level
30         self.pool_type = pool_type
31 
32     def forward(self, x):
33         N, C, H, W = x.size()
34         for i in range(self.num_level):
35             level = i + 1
36             kernel_size = (ceil(H / level), ceil(W / level))
37             stride = (ceil(H / level), ceil(W / level))
38             padding = (floor((kernel_size[0] * level - H + 1) / 2), floor((kernel_size[1] * level - W + 1) / 2))
39 
40             if self.pool_type == 'max_pool':
41                 tensor = (F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
42             else:
43                 tensor = (F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=padding)).view(N, -1)
44             
45             if i == 0:
46                 res = tensor
47             else:
48                 res = torch.cat((res, tensor), 1)
49         return res
50     def __repr__(self):
51         return self.__class__.__name__ + '(' \
52             + 'num_level = ' + str(self.num_level) \
53             + ', pool_type = ' + str(self.pool_type) + ')'
54     
55 
class SPPNet(nn.Module):
    """Small demo CNN: two conv layers, an SPP layer, and a linear head.

    Args:
        num_level: number of pyramid levels for the SPP layer.
        pool_type: pooling type forwarded to ``SpatialPyramidPooling2d``
            (``'max_pool'``, or any other value for average pooling).

    The first conv layer fixes ``in_channels=3``. The final layer produces
    10 outputs per sample, so ``forward`` returns a tensor of shape
    ``(N, 10)``.
    """

    def __init__(self, num_level=3, pool_type='max_pool'):
        super(SPPNet, self).__init__()
        self.num_level = num_level
        self.pool_type = pool_type
        self.feature = nn.Sequential(nn.Conv2d(3, 64, 3),
                                     nn.ReLU(),
                                     nn.MaxPool2d(2),
                                     nn.Conv2d(64, 64, 3),
                                     nn.ReLU())
        self.num_grid = self._cal_num_grids(num_level)
        # pool_type is forwarded here. It used to be stored but never passed
        # on, so the SPP layer always max-pooled.
        self.spp_layer = SpatialPyramidPooling2d(num_level, pool_type)
        # 64 = out_channels of the last Conv2d in self.feature.
        # NOTE(review): there is no activation between the two Linear layers,
        # so together they collapse to a single linear map. Inserting one
        # would shift the state_dict keys (linear.1 -> linear.2), so it is
        # left unchanged here.
        self.linear = nn.Sequential(nn.Linear(self.num_grid * 64, 512),
                                    nn.Linear(512, 10))

    def _cal_num_grids(self, level):
        """Return the total bin count 1*1 + 2*2 + ... + level*level (0 if level < 1)."""
        return sum(i * i for i in range(1, level + 1))

    def forward(self, x):
        """Run feature extractor -> SPP -> linear head on ``x``."""
        # The original printed x.size() on every call (a leftover debug
        # print); it has been removed.
        x = self.feature(x)
        x = self.spp_layer(x)
        return self.linear(x)
82 
if __name__ == '__main__':
    # Smoke run: push one random (1, 3, 64, 64) tensor through the demo net.
    # The tensor is drawn before the model is built, matching the original
    # order of random-number consumption.
    sample = torch.rand((1, 3, 64, 64))
    model = SPPNet()
    print(model(sample))

猜你喜欢

转载自:https://www.cnblogs.com/qinduanyinghua/p/9016235.html