Pytorch---空间特征金字塔SPP模块的实现



一、SPP模块

SPP模块是指定空间特征金字塔模块,是由何凯明在2014年的论文中所提出的。
论文地址如下:
论文地址

该模块的主要作用是:在分类网络中,通过分类器之后,与全连接层连接时,全连接层的形状是固定的,所以必须将输入网络的图片resize成224224,否则当数据传输到全连接层时,权重不匹配,会发生错误。而将所以图片都resize成224224,可能会使图片失真等等。因此SPP模块提出的作用就是,将分类器最后的pooling层变成SPP模块,这样网络可以接受任意尺寸的输入,不需要将输入图片resize成224*224。
SPP模块的结构如下:
在这里插入图片描述
主要理解论文中这句话即可:
These spatial bins have sizes proportional to the image size, so the number of bins is fixed regardless of the image size.
翻译:这些空间箱的大小与图像大小成正比,因此无论图像大小如何,箱的数量都是固定的。
池化窗口的大小和步长都是跟随输入的h和w所变化的,所导致的结果就是,池化之后的h和w一定是4乘以4,2乘以2,1乘以1。
箱指的是小网格

整个的流程如下:
1: 经过分类器的feature map 的尺寸是channel h w(忽略batch)
2: 首先经过第一个最大池化,得到的结果是44大小的,然后经过第二个最大池化,得到的结果是22大小的, 然后经过第三个最大池化,得到的结果是1*1大小的。然后将其展平,拼接起来,就会得到21列的向量。
不论输入的图像尺寸为多少,最后在全连接层之前feature map都会都会变成256 乘以 21大小,其中256是channel。

二、使用pytorch实现

实际上关键就是,动态的求解出池化窗口的k和s大小
在这里插入图片描述

class SPP(torch.nn.Module):
    def __init__(self, input):
        super(SPP, self).__init__()
        self.pool_param = [(4, 4), (2, 2), (1, 1)]

        # 假设h和w相等,不相等的情况,h和w单独处理即可
        h = input.shape[2]
        w = input.shape[3]
        s1 = h // self.pool_param[0][0]
        k1 = h - s1 * (self.pool_param[0][0] - 1)
        self.pool_4_4 = torch.nn.MaxPool2d(kernel_size=(k1, k1), stride=(s1, s1))

        s2 = h // self.pool_param[1][0]
        k2 = h - s2 * (self.pool_param[1][0] - 1)
        self.pool_2_2 = torch.nn.MaxPool2d(kernel_size=(k2, k2), stride=(s2, s2))

        s3 = h // self.pool_param[2][0]
        k3 = h - s2 * (self.pool_param[2][0] - 1)
        self.pool_1_1 = torch.nn.MaxPool2d(kernel_size=(k3, k3), stride=(s3, s3))

    def forward(self, x):
        x1 = self.pool_4_4(x)
        x1 = torch.flatten(x1, start_dim=-2, end_dim=-1)
        x2 = self.pool_2_2(x)
        x2 = torch.flatten(x2, start_dim=-2, end_dim=-1)
        x3 = self.pool_1_1(x)
        x3 = torch.flatten(x3, start_dim=-2, end_dim=-1)
        x = torch.cat((x1, x2, x3), dim=-1)
        return x
 if __name__ == "__main__":
    vgg_model = vgg16_bn(weights = VGG16_BN_Weights.DEFAULT)
    # print(vgg_model)
    # print(list(vgg_model.features.children()))
    test = torch.rand(8, 3, 16, 16)
    model = SPP(test)
    output = model(test)
    print(output.shape)

输入为(8,3,16,16),经过SPP模块之后,大小为
在这里插入图片描述
这里16,16刚好是4,2,1的整数倍,更换其他数字
输入为(8,3,15,15),经过SPP模块之后,大小为在这里插入图片描述
可以再尝试其他数字,只要大于等于4,得到的尺寸都是统一的。

猜你喜欢

转载自blog.csdn.net/weixin_47250738/article/details/133099700
今日推荐