Paper: M2Det: A Single-Shot Object Detector based on Multi-Level Feature Pyramid Network. PDF: https://arxiv.org/pdf/1811.04533.pdf
github: https://github.com/qijiezhao/M2Det
This post mainly covers how the multi-level feature pyramid is generated; the later detection stage, NMS, and other operations are only mentioned in passing.
Headline results: on the MS-COCO benchmark, M2Det achieves an AP of 41.0 at 11.8 FPS with a single-scale inference strategy, and an AP of 44.2 with a multi-scale inference strategy, which are new state-of-the-art results among one-stage detectors.
1. Introduction
Detecting objects of different sizes in an image is challenging. There are two classic strategies: (1) detect objects in an image pyramid built from the input, which greatly increases memory use and time complexity; (2) detect objects in a feature pyramid extracted from a single input image, which uses less memory and has lower computational cost.
The second approach still has a limitation: such detectors simply construct the feature pyramid from the backbone's inherent multi-scale, pyramidal feature hierarchy. As a result, each feature map in the pyramid (used for detecting objects in a specific size range) mainly or only consists of single-level features, which leads to suboptimal detection performance.
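To make strategy (2) concrete, here is a minimal sketch of extracting a feature pyramid from a single forward pass. `TinyBackbone` and its channel sizes are invented for illustration; real detectors tap stages of VGG or ResNet instead:

```python
import torch
import torch.nn as nn

# Toy backbone: each stage halves the spatial size, mimicking how
# VGG/ResNet stages produce progressively coarser feature maps.
class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stages = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU())
            for c_in, c_out in [(3, 16), (16, 32), (32, 64)]
        ])

    def forward(self, x):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)  # keep one feature map per scale
        return feats

img = torch.randn(1, 3, 64, 64)
pyramid = TinyBackbone()(img)
print([tuple(f.shape) for f in pyramid])  # spatial sizes halve: 32, 16, 8
```

One forward pass yields all scales, which is why this strategy is much cheaper than running the detector on a pyramid of resized images.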
The overall architecture of the model is shown above.
2. Model overview:
FFMv1 fuses the backbone features into a base feature map.
Each TUM generates a set of multi-scale feature maps; chaining the TUMs with FFMv2 modules produces multi-level multi-scale features.
SFAM aggregates these features into a feature pyramid.
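The FFMv1 fusion step can be sketched as below. The channel sizes and class name here are illustrative, not the official implementation, but the pattern (compress both backbone maps, upsample the deeper one, concatenate) matches the `base_feature` computation in the forward code later in this post:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal FFMv1 sketch (channel sizes are illustrative): compress a shallow
# and a deep backbone feature, upsample the deep one to the shallow one's
# resolution, and concatenate them into the base feature.
class FFMv1(nn.Module):
    def __init__(self, c_shallow=512, c_deep=1024, c_out=(256, 512)):
        super().__init__()
        self.reduce = nn.Conv2d(c_shallow, c_out[0], 3, padding=1)
        self.up_reduce = nn.Conv2d(c_deep, c_out[1], 1)

    def forward(self, shallow, deep):
        up = F.interpolate(self.up_reduce(deep), scale_factor=2, mode='nearest')
        return torch.cat((self.reduce(shallow), up), 1)

shallow = torch.randn(1, 512, 40, 40)   # e.g. a mid-level backbone map
deep = torch.randn(1, 1024, 20, 20)     # e.g. a deeper, coarser map
base = FFMv1()(shallow, deep)
print(base.shape)  # torch.Size([1, 768, 40, 40])
```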
The output multi-level multi-scale features are calculated as:

    [x_1^l, x_2^l, ..., x_i^l] = T_l(X_base),                if l = 1
                                 T_l(F(X_base, x_i^{l-1})),  if l > 1

where T_l denotes the l-th TUM, F denotes FFMv2, X_base is the base feature produced by FFMv1, and x_i^{l-1} is the largest-scale output of the previous TUM.
The structure of SFAM is shown below:
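SFAM first concatenates same-scale feature maps from every TUM level, then reweights the channels with squeeze-and-excitation style attention. A minimal sketch, with illustrative sizes (8 levels of 128 channels each at one scale):

```python
import torch
import torch.nn as nn

# SE-style channel attention as used in SFAM: global average pool,
# bottleneck MLP, sigmoid gate, channel-wise reweighting.
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction), nn.ReLU(),
            nn.Linear(channels // reduction, channels), nn.Sigmoid(),
        )

    def forward(self, x):
        w = self.fc(x.mean(dim=(2, 3)))  # (N, C) channel weights
        return x * w[:, :, None, None]   # broadcast over spatial dims

levels, c_per_level = 8, 128
# one same-scale feature map per TUM level (sizes are illustrative)
same_scale = [torch.randn(1, c_per_level, 10, 10) for _ in range(levels)]
fused = torch.cat(same_scale, dim=1)       # concatenate along channels
out = SEBlock(levels * c_per_level)(fused)
print(out.shape)  # torch.Size([1, 1024, 10, 10])
```

The attention lets the detector emphasize whichever depth levels are most useful at each scale, instead of weighting all levels equally.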
Experiments:
Following the protocol in MS-COCO, we use the trainval35k set for training, which is a union of the 80k images from the train split and a random 35k subset of images from the 40k-image val split.
The experimental results are attached below.
The forward-pass code from the repo is pasted here for reference:
def forward(self, x):
    loc, conf = list(), list()
    base_feats = list()
    # Run the backbone and collect the two feature maps fed to FFMv1
    if 'vgg' in self.net_family:
        for k in range(len(self.base)):
            x = self.base[k](x)
            if k in self.base_out:
                base_feats.append(x)
    elif 'res' in self.net_family:
        base_feats = self.base(x, self.base_out)
    # FFMv1: reduce both maps, upsample the deeper one, and
    # concatenate them into the base feature
    base_feature = torch.cat(
        (self.reduce(base_feats[0]),
         F.interpolate(self.up_reduce(base_feats[1]), scale_factor=2, mode='nearest')),
        1
    )
    # tum_outs is the multi-level multi-scale feature: the first TUM sees
    # only the base feature; each later TUM also receives the largest-scale
    # output of the previous TUM
    tum_outs = [getattr(self, 'unet{}'.format(1))(self.leach[0](base_feature), 'none')]
    for i in range(1, self.num_levels, 1):
        tum_outs.append(
            getattr(self, 'unet{}'.format(i + 1))(
                self.leach[i](base_feature), tum_outs[i - 1][-1]
            )
        )
    # concat same-scale outputs across all TUM levels
    sources = [torch.cat([_fx[i - 1] for _fx in tum_outs], 1)
               for i in range(self.num_scales, 0, -1)]
    # forward_sfam: channel-wise attention over the aggregated pyramid
    if self.sfam:
        sources = self.sfam_module(sources)
    sources[0] = self.Norm(sources[0])
    # Apply the per-scale localization and classification heads
    for (x, l, c) in zip(sources, self.loc, self.conf):
        loc.append(l(x).permute(0, 2, 3, 1).contiguous())
        conf.append(c(x).permute(0, 2, 3, 1).contiguous())
    loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
    conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
    if self.phase == "test":
        output = (
            loc.view(loc.size(0), -1, 4),                   # loc preds
            self.softmax(conf.view(-1, self.num_classes)),  # conf preds
        )
    else:
        output = (
            loc.view(loc.size(0), -1, 4),
            conf.view(conf.size(0), -1, self.num_classes),
        )
    return output