PyTorch implementation of the spatial pyramid pooling (SPP) module



1. SPP module

The SPP (spatial pyramid pooling) module was proposed by Kaiming He et al. in their 2014 paper "Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition".

The main purpose of this module: in a classification network, the fully connected layers that follow the convolutional feature extractor have a fixed input size, so images fed to the network must be resized to 224×224; otherwise the flattened feature map does not match the weights of the fully connected layer and an error is raised. Resizing every picture to 224×224, however, can distort the images. The SPP module therefore replaces the final pooling layer before the fully connected layers, so that the network can accept input of any size without resizing the image to 224×224.
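To make the shape-mismatch problem concrete, here is a minimal sketch (my own toy example, not code from the paper; the layer sizes are arbitrary). The fully connected layer is sized for a 224×224 input, so any other input size fails:

import torch
import torch.nn as nn

# Toy network (hypothetical example): one conv + pooling stage, then a fully
# connected layer whose in_features only matches a 224x224 input.
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.MaxPool2d(2))
fc = nn.Linear(8 * 112 * 112, 10)

x = torch.rand(1, 3, 224, 224)
print(fc(torch.flatten(features(x), start_dim=1)).shape)  # torch.Size([1, 10])

y = torch.rand(1, 3, 160, 160)
# fc(torch.flatten(features(y), start_dim=1))  # RuntimeError: 8*80*80 != 8*112*112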
The structure of the SPP module is as follows:
(Figure: structure of the SPP module — three max-pooling levels whose outputs are concatenated into a fixed-length vector)
The key sentence to understand in the paper is:
These spatial bins have sizes proportional to the image size, so the number of bins is fixed regardless of the image size.
In other words, the kernel size and stride of the pooling windows scale with the input height h and width w, so the pooled outputs are always 4×4, 2×2 and 1×1, and the number of bins stays fixed regardless of the input size. The bins are the small grid cells of each pooling level.

The entire process is as follows:
1: The feature map produced by the feature extractor has shape channels × h × w (ignoring the batch dimension).
2: The first max pooling gives a 4×4 result, the second max pooling a 2×2 result, and the third max pooling a 1×1 result. Each result is flattened and the three are concatenated, giving 4×4 + 2×2 + 1×1 = 21 values per channel (see the short check below).
No matter what the input image size is, the feature map before the fully connected layer therefore always becomes 256 × 21, where 256 is the number of channels.
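A one-line check of that bin arithmetic (a trivial sketch, just to make the 21 explicit):

# Number of pooled values per channel across the three pyramid levels.
bins = 4 * 4 + 2 * 2 + 1 * 1
print(bins)  # 21 -> with 256 channels the fixed-length feature is 256 x 21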

2. Implementation in PyTorch

The key is to compute the kernel size k and stride s of each pooling window dynamically from the input size. For an n×n output level on an h×h feature map, the code below uses stride s = h // n and kernel size k = h - s*(n - 1).
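A quick check of that formula (my own sketch; h = 13 is an arbitrary size that is not a multiple of 4). The pooled output size (h - k) // s + 1 comes out to exactly n for every level:

# Verify that s = h // n, k = h - s*(n - 1) always yields an n x n pooled output.
h = 13  # arbitrary feature-map size, not a multiple of 4
for n in (4, 2, 1):
    s = h // n
    k = h - s * (n - 1)
    out = (h - k) // s + 1  # output size of MaxPool2d(kernel_size=k, stride=s)
    print(f"n={n}: kernel={k}, stride={s}, output={out}x{out}")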

import torch
from torchvision.models import vgg16_bn, VGG16_BN_Weights


class SPP(torch.nn.Module):
    def __init__(self, input):
        super(SPP, self).__init__()
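        # Target output sizes of the three pyramid levels: 4x4, 2x2 and 1x1.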
        self.pool_param = [(4, 4), (2, 2), (1, 1)]

        # Assume h and w are equal; if they differ, compute the kernel/stride for h and w separately.
        h = input.shape[2]
        w = input.shape[3]
        s1 = h // self.pool_param[0][0]
        k1 = h - s1 * (self.pool_param[0][0] - 1)
        self.pool_4_4 = torch.nn.MaxPool2d(kernel_size=(k1, k1), stride=(s1, s1))

        s2 = h // self.pool_param[1][0]
        k2 = h - s2 * (self.pool_param[1][0] - 1)
        self.pool_2_2 = torch.nn.MaxPool2d(kernel_size=(k2, k2), stride=(s2, s2))

        s3 = h // self.pool_param[2][0]
        k3 = h - s3 * (self.pool_param[2][0] - 1)
        self.pool_1_1 = torch.nn.MaxPool2d(kernel_size=(k3, k3), stride=(s3, s3))

    def forward(self, x):
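        # x: (B, C, h, w); each branch pools to a fixed grid, then flattens the two spatial dims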
        x1 = self.pool_4_4(x)
        x1 = torch.flatten(x1, start_dim=-2, end_dim=-1)
        x2 = self.pool_2_2(x)
        x2 = torch.flatten(x2, start_dim=-2, end_dim=-1)
        x3 = self.pool_1_1(x)
        x3 = torch.flatten(x3, start_dim=-2, end_dim=-1)
        x = torch.cat((x1, x2, x3), dim=-1)
        return x


if __name__ == "__main__":
    vgg_model = vgg16_bn(weights=VGG16_BN_Weights.DEFAULT)
    # print(vgg_model)
    # print(list(vgg_model.features.children()))
    test = torch.rand(8, 3, 16, 16)
    model = SPP(test)
    output = model(test)
    print(output.shape)

With an input of (8, 3, 16, 16), the shape after the SPP module is torch.Size([8, 3, 21]).
Here 16 is an exact multiple of 4, 2 and 1. Changing the input to (8, 3, 15, 15), the shape after the SPP module is still torch.Size([8, 3, 21]).
You can try other sizes: as long as h and w are at least 4, the resulting size is always the same.
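For reference, PyTorch's built-in torch.nn.AdaptiveMaxPool2d produces the same fixed output grids without computing the kernel size and stride by hand. A minimal sketch of an equivalent SPP layer (my own variant, not the author's code):

import torch
import torch.nn as nn

class AdaptiveSPP(nn.Module):
    """SPP via adaptive pooling: each level always yields the requested grid."""
    def __init__(self, levels=(4, 2, 1)):
        super().__init__()
        self.pools = nn.ModuleList(nn.AdaptiveMaxPool2d(n) for n in levels)

    def forward(self, x):
        # Pool to 4x4, 2x2 and 1x1, flatten the spatial dims and concatenate -> (B, C, 21)
        return torch.cat([torch.flatten(p(x), start_dim=-2) for p in self.pools], dim=-1)

print(AdaptiveSPP()(torch.rand(8, 3, 15, 15)).shape)  # torch.Size([8, 3, 21])

The windows chosen by adaptive pooling can differ slightly from the hand-computed k and s above, but the output grids, and hence the 21 bins per channel, are the same.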

Origin blog.csdn.net/weixin_47250738/article/details/133099700