[Paper Notes] Receptive Field Block Net for Accurate and Fast Object Detection

Brief introduction

Building on SSD, the paper proposes the RFB (Receptive Field Block) module and draws on prior knowledge from neuroscience to explain why it helps. It is essentially a new structure designed to enlarge the receptive field. The paper points out that receptive fields in the human retina vary with eccentricity: the farther from the center of gaze, the larger the receptive field, and the closer to the center, the smaller. Based on this, the RFB module is proposed to simulate this property of human vision.

RFB Module

The structure of the RFB module is shown in the figure below.

Why use dilated convolutions?

First, to enlarge the receptive field, the intuitive options are to deepen the network, to use larger convolution kernels, or to pool before convolving. Deepening the network adds many parameters and defeats the goal of a lightweight model; larger kernels likewise multiply the parameter count; pooling adds no parameters, but it discards information, which hurts the transfer of information to later layers. So it is natural to turn to dilated (atrous) convolution, which enlarges the receptive field without adding any parameters.
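
A quick sanity check (a minimal sketch of my own, not from the paper): a 3 × 3 convolution with dilation 3 has exactly the same number of parameters as an ordinary 3 × 3 convolution, while covering a much larger area.

import torch.nn as nn

def count(m):
    return sum(p.numel() for p in m.parameters())

plain = nn.Conv2d(64, 64, kernel_size=3, padding=1)                # sees a 3x3 area
dilated = nn.Conv2d(64, 64, kernel_size=3, dilation=3, padding=3)  # sees a 7x7 area
print(count(plain), count(dilated))  # 36928 36928 -- identical parameter count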

Why the multi-branch structure?

The multi-branch structure captures information at different receptive field sizes. As mentioned earlier, human vision is characterized by receptive fields whose size depends on the distance from the center of gaze. With a multi-branch structure, each branch captures one receptive field size, and the branches are finally fused by concatenation, simulating this effect of human vision. The author also gives a figure to illustrate this.
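
Concretely, a 3 × 3 convolution with dilation d has an effective kernel size of d·(3−1)+1, so the dilation rates 1, 3 and 5 used in the code below give three clearly separated receptive field sizes:

# Effective kernel size of a 3x3 convolution with dilation d: d*(3-1) + 1
for d in (1, 3, 5):
    k = d * (3 - 1) + 1
    print(f"dilation {d}: effective kernel {k}x{k}")
# dilation 1: effective kernel 3x3
# dilation 3: effective kernel 7x7
# dilation 5: effective kernel 11x11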

Why propose two versions of the RFB?

On the left is the original RFB; compared with it, the RFB-s on the right replaces the 3 × 3 convolution with two branches using 1 × 3 and 3 × 1 kernels. This, first, reduces the parameter count and, second, adds smaller receptive fields, again mimicking the human visual system, which captures smaller receptive fields near the center of gaze.
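
The saving is easy to verify (again a small sketch of my own): an asymmetric 1 × 3 or 3 × 1 kernel carries one third of the weights of a full 3 × 3 kernel.

import torch.nn as nn

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(nn.Conv2d(64, 64, kernel_size=3, padding=1)))            # 36928
print(count(nn.Conv2d(64, 64, kernel_size=(1, 3), padding=(0, 1))))  # 12352
print(count(nn.Conv2d(64, 64, kernel_size=(3, 1), padding=(1, 0))))  # 12352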

Network architecture

The overall network structure is shown below; it is easy to follow.

The front end is a VGG16 backbone, and six prediction branches are taken from intermediate feature maps; there is nothing hard to understand here.

Code reproduction

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# RFB module: three parallel branches whose final convs use dilation rates
# 1, 3 and 5, fused by concatenation and a 1x1 conv, plus a 1x1 shortcut.
class RFBModule(nn.Module):
    def __init__(self, out, stride=1):
        super(RFBModule, self).__init__()
        # branch 1: 1x1 conv -> 3x3 conv (dilation 1)
        self.s1 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=3, dilation=1, padding=1, stride=stride)
        )
        # branch 2: 1x1 conv -> 3x3 conv -> 3x3 conv with dilation 3
        self.s2 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=3, padding=1),
            nn.Conv2d(out, out, kernel_size=3, dilation=3, padding=3, stride=stride)
        )
        # branch 3: 1x1 conv -> 5x5 conv -> 3x3 conv with dilation 5
        self.s3 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=5, padding=2),
            nn.Conv2d(out, out, kernel_size=3, dilation=5, padding=5, stride=stride)
        )
        self.shortcut = nn.Conv2d(out, out, kernel_size=1, stride=stride)
        self.conv1x1 = nn.Conv2d(out * 3, out, kernel_size=1)

    def forward(self, x):
        s1 = self.s1(x)
        s2 = self.s2(x)
        s3 = self.s3(x)
        mix = torch.cat([s1, s2, s3], dim=1)  # fuse the three receptive fields
        mix = self.conv1x1(mix)
        shortcut = self.shortcut(x)
        return mix + shortcut
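
# A quick shape check for RFBModule (illustrative, not part of the original post):
#   m = RFBModule(out=64, stride=2)
#   m(torch.randn(1, 64, 38, 38)).size()  # -> torch.Size([1, 64, 19, 19])
# All three branches and the shortcut downsample consistently, so the
# element-wise sum at the end is always valid.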

# RFB-s module: adds asymmetric 1x3 and 3x1 convs for smaller receptive
# fields; four branches fused by concatenation and a 1x1 conv, plus a
# 1x1 shortcut.
class RFBsModule(nn.Module):
    def __init__(self, out, stride=1):
        super(RFBsModule, self).__init__()
        # branch 1: 1x1 conv -> 3x3 conv (dilation 1)
        self.s1 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=3, dilation=1, padding=1, stride=stride)
        )
        # branch 2: 1x1 conv -> 1x3 conv -> 3x3 conv with dilation 3
        self.s2 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=(1, 3), padding=(0, 1)),
            nn.Conv2d(out, out, kernel_size=3, dilation=3, padding=3, stride=stride)
        )
        # branch 3: 1x1 conv -> 3x1 conv -> 3x3 conv with dilation 3
        self.s3 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=(3, 1), padding=(1, 0)),
            nn.Conv2d(out, out, kernel_size=3, dilation=3, padding=3, stride=stride)
        )
        # branch 4: 1x1 conv -> unpadded 3x3 conv -> 3x3 conv with dilation 5
        # (padding=6 instead of 5 compensates for the unpadded 3x3)
        self.s4 = nn.Sequential(
            nn.Conv2d(out, out, kernel_size=1),
            nn.Conv2d(out, out, kernel_size=3),
            nn.Conv2d(out, out, kernel_size=3, dilation=5, stride=stride, padding=6)
        )
        self.shortcut = nn.Conv2d(out, out, kernel_size=1, stride=stride)
        self.conv1x1 = nn.Conv2d(out * 4, out, kernel_size=1)

    def forward(self, x):
        s1 = self.s1(x)
        s2 = self.s2(x)
        s3 = self.s3(x)
        s4 = self.s4(x)
        mix = torch.cat([s1, s2, s3, s4], dim=1)  # fuse the four receptive fields
        mix = self.conv1x1(mix)
        shortcut = self.shortcut(x)
        return mix + shortcut
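
# A quick shape check for RFBsModule (illustrative):
#   m = RFBsModule(out=64, stride=1)
#   m(torch.randn(1, 64, 37, 37)).size()  # -> torch.Size([1, 64, 37, 37])
# Note branch s4: its unpadded 3x3 conv shrinks the map by 2 pixels, which
# the dilated conv's padding of 6 (rather than 5) compensates for.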

# SSD-style detector: a VGG16 backbone with six detection feature maps;
# det1 uses RFB-s, det2-det4 use RFB, det5/det6 are plain 3x3 convs.
class RFBNet(nn.Module):
    def __init__(self):
        super(RFBNet, self).__init__()
        # VGG16 conv1_1 through conv4_3 (three max-pools)
        self.feature_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # VGG16 conv5_1 through conv5_3 (after the fourth max-pool)
        self.feature_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pre = nn.Conv2d(512, 64, kernel_size=1)  # reduce conv5_3 features to 64 channels
        self.fc = nn.Conv2d(512, 64, kernel_size=1)   # reduce conv4_3 features to 64 channels
        self.det1 = RFBsModule(out=64, stride=1)      # 37x37 map for a 300x300 input
        self.det2 = RFBModule(out=64, stride=1)       # 18x18 map
        self.det3 = RFBModule(out=64, stride=2)       # 9x9 map
        self.det4 = RFBModule(out=64, stride=2)       # 5x5 map
        self.det5 = nn.Conv2d(64, 64, kernel_size=3)  # 3x3 map (unpadded conv)
        self.det6 = nn.Conv2d(64, 64, kernel_size=3)  # 1x1 map
        
    def forward(self, x):
        x = self.feature_1(x)          # conv4_3 features
        det1 = self.det1(self.fc(x))   # first detection map (RFB-s)
        x = self.feature_2(x)          # conv5_3 features
        x = self.pre(x)
        det2 = self.det2(x)
        det3 = self.det3(det2)
        det4 = self.det4(det3)
        det5 = self.det5(det4)
        det6 = self.det6(det5)
        # flatten each map to (N, H*W, 64) and concatenate along dim 1
        det1 = det1.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        det2 = det2.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        det3 = det3.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        det4 = det4.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        det5 = det5.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        det6 = det6.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 64)
        return torch.cat([det1, det2, det3, det4, det5, det6], dim=1)

if __name__ == "__main__":
    net = RFBNet()
    x = torch.randn(2, 3, 300, 300)
    summary(net, (3, 300, 300), device="cpu")
    print(net(x).size())
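
For a 300×300 input, the six detection maps come out as 37×37, 18×18, 9×9, 5×5, 3×3 and 1×1 (PyTorch's MaxPool2d floors by default, so the conv4_3 map here is 37×37 rather than SSD's usual 38×38), giving 1369 + 324 + 81 + 25 + 9 + 1 = 1809 locations in total; the script should print torch.Size([2, 1809, 64]).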

Original paper: https://arxiv.org/pdf/1711.07767.pdf
