EfficientDet (BiFPN) (CVPR 2020)

Introduction

Paper: https://arxiv.org/abs/1911.09070v7
Code: https://github.com/google/automl/tree/master/efficientdet
Reimplementation: https://github.com/jewelc92/mmdetection/blob/3.x/projects/EfficientDet/efficientdet/bifpn.py

Principle:

  • Efficient bidirectional cross-scale connections
  • Weighted feature fusion

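Figure 2: Feature network design. (a) FPN introduces a top-down pathway to fuse multi-scale features from level 3 to level 7 (P3 - P7); (b) PANet adds an extra bottom-up pathway on top of FPN; (c) NAS-FPN uses neural architecture search to find an irregular feature network topology and then applies the same block repeatedly; (d) BiFPN, the paper's design, which offers a better trade-off between accuracy and efficiency.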

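The paper proposes several optimizations for cross-scale connections:

  • First, remove nodes that have only one input edge. A node with a single input performs no feature fusion, so it contributes little to a network whose purpose is to fuse different features.
  • Second, if the original input node and the output node are at the same level, add an extra edge from the input to the output, so that more features are fused without adding much computational cost.
  • Finally, unlike PANet, which has only one top-down and one bottom-up path, the paper treats each bidirectional (top-down plus bottom-up) path as one feature network layer and repeats that layer multiple times to achieve higher-level feature fusion.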
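
The first two rules can be checked on a small graph. The sketch below is only an illustration, not code from the paper or the repository: the node names ("in", "td", "out") and all helper functions are hypothetical, and the PANet-style starting graph is an assumption based on the Figure 2 caption. It starts from one top-down plus one bottom-up pass over P3 - P7, bypasses single-input nodes, adds the same-level skip edges, and prints the number of inputs each remaining fusion node receives.

```python
LEVELS = [3, 4, 5, 6, 7]  # P3..P7, as in the Figure 2 caption


def panet_edges():
    """One top-down and one bottom-up pass, as (source, destination) pairs.

    Hypothetical node names: ("in", l) is the input at level l, ("td", l) the
    top-down intermediate node, and ("out", l) the bottom-up output node.
    """
    edges = []
    for lvl in LEVELS:
        edges.append((("in", lvl), ("td", lvl)))
        if lvl < LEVELS[-1]:
            edges.append((("td", lvl + 1), ("td", lvl)))
        edges.append((("td", lvl), ("out", lvl)))
        if lvl > LEVELS[0]:
            edges.append((("out", lvl - 1), ("out", lvl)))
    return edges


def fan_in(edges):
    """Map each destination node to the list of its source nodes."""
    inputs = {}
    for src, dst in edges:
        inputs.setdefault(dst, []).append(src)
    return inputs


def drop_single_input_nodes(edges):
    """Rule 1: bypass every node that has exactly one input edge."""
    edges = list(edges)
    while True:
        single = [n for n, srcs in fan_in(edges).items() if len(srcs) == 1]
        if not single:
            return edges
        node = single[0]
        pred = fan_in(edges)[node][0]
        # Remove the edge into the node and reroute its outgoing edges.
        edges = [(pred if s == node else s, d) for s, d in edges if d != node]


def add_same_level_skips(edges):
    """Rule 2: connect each input to the output node at the same level."""
    edges = list(edges)
    nodes = {n for edge in edges for n in edge}
    for lvl in LEVELS:
        skip = (("in", lvl), ("out", lvl))
        if ("out", lvl) in nodes and skip not in edges:
            edges.append(skip)
    return edges


bifpn_edges = add_same_level_skips(drop_single_input_nodes(panet_edges()))
fan_in_sizes = {node: len(srcs) for node, srcs in fan_in(bifpn_edges).items()}
for node in sorted(fan_in_sizes):
    print(node, fan_in_sizes[node])
```

In this sketch the P3 output ends up being the node named ("td", 3), because the single-input ("out", 3) node is bypassed. The resulting input counts (two for the top-down nodes and for the P7 output, three for the P4 - P6 outputs) line up with the sizes of the p*_w1 and p*_w2 weight vectors in the code further down.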

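When fusing features of different resolutions, the common approach is to resize them to the same resolution and then sum them. Earlier methods treat all input features equally, but the authors observe that, because the inputs have different resolutions, they usually contribute unequally to the output. The paper therefore adds a learnable weight to each input, so that the network learns the importance of each input feature.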
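
The code later in this post computes the fusion as O = Σ_i (w_i / (ε + Σ_j w_j)) · I_i, where each w_i is first passed through a ReLU. The following is a minimal sketch of that arithmetic in plain Python, assuming short lists of floats as stand-ins for feature maps; the function name and the sample values are hypothetical and not part of the repository.

```python
def fast_normalized_fusion(features, raw_weights, epsilon=1e-4):
    """Fuse equally sized feature vectors with one scalar weight per input.

    features: list of equal-length lists of floats (stand-ins for feature maps).
    raw_weights: one unconstrained scalar per input feature.
    """
    # ReLU keeps every weight non-negative, like the p*_w*_relu modules.
    weights = [max(w, 0.0) for w in raw_weights]
    # Normalize by the sum; epsilon avoids division by zero.
    total = sum(weights) + epsilon
    normalized = [w / total for w in weights]
    fused = [
        sum(w * feature[i] for w, feature in zip(normalized, features))
        for i in range(len(features[0]))
    ]
    return fused, normalized


p4_in = [1.0, 2.0]         # same-level input (hypothetical values)
p5_upsampled = [3.0, 4.0]  # coarser level, already resized to match
fused, normalized = fast_normalized_fusion([p4_in, p5_upsampled], [1.0, 1.0])
print(normalized)
print(fused)
```

With equal raw weights the result is close to a plain average of the inputs. A negative raw weight is clipped to zero by the ReLU, so that input drops out of the sum entirely.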


Code

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
from typing import List
 
import torch
import torch.nn as nn
from mmcv.cnn.bricks import Swish
from mmengine.model import BaseModule
 
from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType
from .utils import (DepthWiseConvBlock, DownChannelBlock, MaxPool2dSamePadding,
                    MemoryEfficientSwish)
 
 
class BiFPNStage(nn.Module):
    '''
        in_channels: List[int], input dim for P3, P4, P5
        out_channels: int, output dim for P3 - P7
        first_time: bool, whether this is the first BiFPNStage
        apply_bn_for_resampling: bool, whether to use BN after resampling
        conv_bn_act_pattern: bool, whether to use the conv_bn_act pattern
        use_meswish: bool, whether to use MemoryEfficientSwish
        norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer.
        epsilon: float, hyperparameter in fusion features
    '''
 
    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 first_time: bool = False,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 use_meswish: bool = True,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 epsilon: float = 1e-4) -> None:
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_time = first_time
        self.apply_bn_for_resampling = apply_bn_for_resampling
        self.conv_bn_act_pattern = conv_bn_act_pattern
        self.use_meswish = use_meswish
        self.norm_cfg = norm_cfg
        self.epsilon = epsilon
 
        if self.first_time:
            self.p5_down_channel = DownChannelBlock(
                self.in_channels[-1],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p4_down_channel = DownChannelBlock(
                self.in_channels[-2],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p3_down_channel = DownChannelBlock(
                self.in_channels[-3],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p5_to_p6 = nn.Sequential(
                DownChannelBlock(
                    self.in_channels[-1],
                    self.out_channels,
                    apply_norm=self.apply_bn_for_resampling,
                    conv_bn_act_pattern=self.conv_bn_act_pattern,
                    norm_cfg=norm_cfg), MaxPool2dSamePadding(3, 2))
            self.p6_to_p7 = MaxPool2dSamePadding(3, 2)
            self.p4_level_connection = DownChannelBlock(
                self.in_channels[-2],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
            self.p5_level_connection = DownChannelBlock(
                self.in_channels[-1],
                self.out_channels,
                apply_norm=self.apply_bn_for_resampling,
                conv_bn_act_pattern=self.conv_bn_act_pattern,
                norm_cfg=norm_cfg)
 
        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')
 
        # bottom to up: feature map down_sample module
        self.p4_down_sample = MaxPool2dSamePadding(3, 2)
        self.p5_down_sample = MaxPool2dSamePadding(3, 2)
        self.p6_down_sample = MaxPool2dSamePadding(3, 2)
        self.p7_down_sample = MaxPool2dSamePadding(3, 2)
 
        # Fuse Conv Layers
        self.conv6_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv5_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv4_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv3_up = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv4_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv5_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv6_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        self.conv7_down = DepthWiseConvBlock(
            out_channels,
            out_channels,
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)
        # weights
        self.p6_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p5_w1_relu = nn.ReLU()
        self.p4_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p4_w1_relu = nn.ReLU()
        self.p3_w1 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p3_w1_relu = nn.ReLU()
 
        self.p4_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p4_w2_relu = nn.ReLU()
        self.p5_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p5_w2_relu = nn.ReLU()
        self.p6_w2 = nn.Parameter(
            torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p6_w2_relu = nn.ReLU()
        self.p7_w2 = nn.Parameter(
            torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p7_w2_relu = nn.ReLU()
 
        self.swish = MemoryEfficientSwish() if use_meswish else Swish()
 
    def combine(self, x):
        if not self.conv_bn_act_pattern:
            x = self.swish(x)
 
        return x
 
    def forward(self, x):
        if self.first_time:
            p3, p4, p5 = x  # [(1,40,64,64),(1,112,32,32),(1,320,16,16)]
            # build feature map P6
            p6_in = self.p5_to_p6(p5)  # (1,64,8,8)
            # build feature map P7
            p7_in = self.p6_to_p7(p6_in)  # (1,64,4,4)
 
            p3_in = self.p3_down_channel(p3)  # (1,64,64,64)
            p4_in = self.p4_down_channel(p4)  # (1,64,32,32)
            p5_in = self.p5_down_channel(p5)  # (1,64,16,16)
 
        else:
            p3_in, p4_in, p5_in, p6_in, p7_in = x
 
        # Weights for P6_0 and P7_0 to P6_1
        p6_w1 = self.p6_w1_relu(self.p6_w1)
        weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
        # Connections for P6_0 and P7_0 to P6_1 respectively
        p6_up = self.conv6_up(
            self.combine(weight[0] * p6_in +
                         weight[1] * self.p6_upsample(p7_in)))  # (1,64,8,8)
 
        # Weights for P5_0 and P6_1 to P5_1
        p5_w1 = self.p5_w1_relu(self.p5_w1)
        weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
        # Connections for P5_0 and P6_1 to P5_1 respectively
        p5_up = self.conv5_up(
            self.combine(weight[0] * p5_in +
                         weight[1] * self.p5_upsample(p6_up)))  # (1,64,16,16)
 
        # Weights for P4_0 and P5_1 to P4_1
        p4_w1 = self.p4_w1_relu(self.p4_w1)
        weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
        # Connections for P4_0 and P5_1 to P4_1 respectively
        p4_up = self.conv4_up(
            self.combine(weight[0] * p4_in +
                         weight[1] * self.p4_upsample(p5_up)))  # (1,64,32,32)
 
        # Weights for P3_0 and P4_1 to P3_2
        p3_w1 = self.p3_w1_relu(self.p3_w1)
        weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
        # Connections for P3_0 and P4_1 to P3_2 respectively
        p3_out = self.conv3_up(
            self.combine(weight[0] * p3_in +
                         weight[1] * self.p3_upsample(p4_up)))  # (1,64,64,64)
 
        if self.first_time:
            # p4_level_connection has the same structure as p4_down_channel
            # but separate weights, so p4_in from above is not reused here.
            p4_in = self.p4_level_connection(p4)
            p5_in = self.p5_level_connection(p5)
 
        # Weights for P4_0, P4_1 and P3_2 to P4_2
        p4_w2 = self.p4_w2_relu(self.p4_w2)
        weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
        p4_out = self.conv4_down(
            self.combine(weight[0] * p4_in + weight[1] * p4_up +
                         weight[2] * self.p4_down_sample(p3_out)))  # (1,64,32,32)
 
        # Weights for P5_0, P5_1 and P4_2 to P5_2
        p5_w2 = self.p5_w2_relu(self.p5_w2)
        weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
        p5_out = self.conv5_down(
            self.combine(weight[0] * p5_in + weight[1] * p5_up +
                         weight[2] * self.p5_down_sample(p4_out)))  # (1,64,16,16)
 
        # Weights for P6_0, P6_1 and P5_2 to P6_2
        p6_w2 = self.p6_w2_relu(self.p6_w2)
        weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
        p6_out = self.conv6_down(
            self.combine(weight[0] * p6_in + weight[1] * p6_up +
                         weight[2] * self.p6_down_sample(p5_out)))  # (1,64,8,8)
 
        # Weights for P7_0 and P6_2 to P7_2
        p7_w2 = self.p7_w2_relu(self.p7_w2)
        weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
        # Connections for P7_0 and P6_2 to P7_2
        p7_out = self.conv7_down(
            self.combine(weight[0] * p7_in +
                         weight[1] * self.p7_down_sample(p6_out)))  # (1,64,4,4)
        return p3_out, p4_out, p5_out, p6_out, p7_out
 
 
@MODELS.register_module()
class BiFPN(BaseModule):
    '''
        num_stages: int, number of times the BiFPN stage is repeated
        in_channels: List[int], input dim for P3, P4, P5
        out_channels: int, output dim for P3 - P7
        start_level: int, index of the first backbone feature map to use
        epsilon: float, hyperparameter in fusion features
        apply_bn_for_resampling: bool, whether to use BN after resampling
        conv_bn_act_pattern: bool, whether to use the conv_bn_act pattern
        use_meswish: bool, whether to use MemoryEfficientSwish
        norm_cfg: (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer.
        init_cfg: MultiConfig: init method
    '''
 
    def __init__(self,
                 num_stages: int,
                 in_channels: List[int],
                 out_channels: int,
                 start_level: int = 0,
                 epsilon: float = 1e-4,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 use_meswish: bool = True,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 init_cfg: MultiConfig = None) -> None:
 
        super().__init__(init_cfg=init_cfg)
        self.start_level = start_level
        self.bifpn = nn.Sequential(*[
            BiFPNStage(
                in_channels=in_channels,
                out_channels=out_channels,
                first_time=(stage_idx == 0),
                apply_bn_for_resampling=apply_bn_for_resampling,
                conv_bn_act_pattern=conv_bn_act_pattern,
                use_meswish=use_meswish,
                norm_cfg=norm_cfg,
                epsilon=epsilon) for stage_idx in range(num_stages)
        ])
 
    def forward(self, x):
        # [(1,40,64,64),(1,112,32,32),(1,320,16,16)]
        x = x[self.start_level:]
        x = self.bifpn(x)
 
        return x
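
The shape comments in forward() assume a 512 x 512 input. The sketch below reproduces those spatial sizes in plain Python under two assumptions that are not stated in the listing: the backbone emits P3, P4, P5 at strides 8, 16, 32 with ceil rounding, and MaxPool2dSamePadding(3, 2) halves the size with ceil rounding. The helper names are hypothetical. It also checks whether the nearest-neighbor 2x upsampling in the top-down path yields sizes that match the next finer level.

```python
import math


def pyramid_sizes(input_size, strides=(8, 16, 32)):
    """Spatial sizes of P3..P7 for a square input (assumptions in the text)."""
    sizes = [math.ceil(input_size / s) for s in strides]  # P3, P4, P5
    sizes.append(math.ceil(sizes[-1] / 2))  # P6 from stride-2 pooling of P5
    sizes.append(math.ceil(sizes[-1] / 2))  # P7 from stride-2 pooling of P6
    return sizes


def top_down_compatible(sizes):
    """True if upsampling each level by 2 gives the size of the finer level."""
    return all(2 * coarse == fine for fine, coarse in zip(sizes, sizes[1:]))


for size in (512, 600):
    sizes = pyramid_sizes(size)
    print(size, sizes, top_down_compatible(sizes))
```

Under these assumptions a 512 input gives the 64, 32, 16, 8, 4 sizes noted in the comments, while an input such as 600 produces levels whose upsampled size does not match the finer level, so the weighted sums in the top-down path would fail on a shape mismatch.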

References

https://blog.csdn.net/ooooocj/article/details/130126690


Reposted from blog.csdn.net/qq_54372122/article/details/130749816