CVPR17论文有感:A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection

A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection

还是那个老故事,即如何解决Deep检测跟踪器训练时正样本(尤其在occlusion和deformation情况下的hard positive)不足的问题。这个问题的紧迫性在于这样样本的缺失导致训练所得网络在occlusion、deformation等appearance variation情况下的鲁棒性不足。

传统解决此问题有两个方向的思路:1. 尽可能收集和建立越来越大越来越全的数据库,期待这个数据库能够把方方面面的variation都囊括(比如COCO超过10K的汽车样本with variations)。可是,occlusion和deformation具有long-tail的特性,即便再大的dataset也很难得到一个足够全的数据库;2. Hard samples mining in Loss function,比如类似focal loss,这样做仅仅是增加了在训练时hard样本的单体对loss的贡献,但是依然不能解决有些很罕见样本根本收集不到的问题。

在解决此问题上,本文与SINT++的思路完全一样,即通过Adversarial的概念去直接学习从hard positive的long-tail distribution中学习采样生成现实中不存在的hard positive样本,学会如何去遮挡一些真的easy positive。注意,这里的生成依旧不是输入一张图,输出一张图的传统GAN模式。因为这样的传统GAN模式依旧摆脱不了对样本的需求。这里做的是在CNN的feature map上进行遮挡,而不是pixel-wise的逐像素生成,这样一来就对adversary训练的样本需求减少很多。 注意,A-Fast-RCNN和SINT++都是在feature map上遮挡而不是在原图上遮挡。Adversary的概念体现在generator拼命生成discriminator(classifier、detector)无法识别的hard positive,而discriminator则拼命去识别generator扔过来的尽可能难的hard positive。

 

  • 论文梗概:

本文是自实现better data utilization and provide a new way of hard positive data enhancement的角度入手得到新detector的。换言之,本文的贡献主要集中在如何用adversary的概念在已有sample基础上对抗地(generator try to predict occlusion and deformation that will degrade the detector,while the detector learn to overcome it)学会遮挡,生成新hard样本。而detector本身就是一个Fast-RCNN。

一个detector的核心能力体现在invariance to appearance variance(包括illumination,deformation,occlusion,intra-class variation,etc.)。为了使得训练得到检测器获得这种invariance的能力,最原始的思路是暴力收集建立尽可能大的dataset,使得这个dataset有丰富的hard样本可以cover各种可能的variance。但是,deformation和occlusion等variance具有long-tail distribution的特性,即有大部分variance是极其罕见的,想要完全收集齐是不现实的。

那么该怎么办呢?那么就通过生成的方法,依据现有的easy样本去创造hard正样本。即通过adverary学习如何在easy sample上面进行遮挡以使其变成hard positive。

  • 深度检测器背景调研

两大阵营的发展:

  • One-stage (Spatially continuous: Sliding window):Overfeat à YOLO à SSD à DenseBox;
  • Two-stage (Spatially discrete: Region Proposal): R-CNN à Fast-RCNN(compute the conv features of the entire image once and share it among proposals) à Faster-RCNN(E2E with RPN) à R-FCN

新detector设计方向:

  • (网络设计)Go Deeper on the CNN main body:ResNet, Inception-ResNet, ResNetXt;
  • (网络设计)Contextual reasoning & top-down mechanism: segmentation first, CNN feature hierarchy skip connection, CNN feature hierarchy lateral connection;
  • (数据增强)Get better and more data: Hard-positive mining, etc.
  • 技术细节:

传统的Faster-RCNN有两部分输出:分类类别和bbox位置。相对应地,在计算loss时也由两部分组成:(F()是detector,X是一个proposal的feature)

那么本文的工作,就是在训练时在原网络上增加一个adversarial network,此时训练的loss变为:

上述A(x)是将原proposal X的feature通过adversarial network生成的一个‘新’的feature。La loss的关键是前面的负号,这也是adversary对抗概念的关键:即,使得detector分类产生越大误差的A()说明造假成功,则loss越小;而使得detector分类越准确的A()说明造价不力,则loss越大。也就是说,在adversarial训练的过程中,使得A()越来越能够产生破坏detector分类性能的新样本,而在对抗下,detector越来越能适应A()带来的新样本。

           A()的造假是在原先真样本上动两种手脚:1 occlusion,2. Deformation(partial rotation),将其改造为一个更challenge更rare的假样本。产生occlusion的部分叫做Adversarial Spatial Dropout Network(ASDN);而产生partial rotation的部分叫做Adversarial Spatial Transform Network(ASTN)。两个网络串联,ASTN在ASDN之后,流水线一样地对一个输入的真样本进行改造,一次对其加之occlusion和rotation。

值得注意的是:

  • ASDN和ASTN都是feature-level operation而不是image level operation,即不是像GAN一样输入一个图,pixel-wise地generate并输出一个图;
  • Adversarial network只是作为generator在训练的时候存在于网络中,在训练过程中对抗地提高detector的性能,而在测试时将被去掉,仅剩下原本的Faster-RCNN进行测试,只不过这样训练得到的detector比普通训练下得到的更难强壮;
  • ASDN从概念到implements和2018 CVPR VITAL一文的设计都太像了。。。。。。不得不说,后来者VITAL需要避嫌啊。。。不过VITAL里面提到的feature level的masking是为了筛选出robust的feature,去掉discriminative的feature,而这里是说要用同样1D的mask(掩码)来给原feature map加权上遮挡。Anyway,都是一个masking+drop-out的过程。。。太像了。。。

 

  1. ASDN for Occlusion

ASDN就是输入原某个proposal在roi pooling之后的feature输入两层FC(IMageNet pretrained Fc6 and FC7)生成一个二值掩码图mask,这个掩码图标识着feature map中哪些部分应当被遮挡掉(对应位置清零)。再将这个mask与原输入feature map(FM)相乘(dropout)得到一个masked FM。

  • 网络训练

如上图可见ASDN的网络结构非常简单,但是训练过程非常复杂。采用stage-wise分阶段的训练,即先训Fast-RCNN不加ASDN for 10K;然后fix Fast-RCNN的所有参数,单独训练ASDN for 10k;最终将两者全部上线串联起来E2E训练。

  • 单独训练Fast-RCNN:去掉ASDN,训练10K。
  • 单独训练ASDN:fix Fast-RCNN所有参数,单独训练ASDN。

       Training Sample Generation for ASDN

ASDN的训练样本应当是(X,M)对儿,一个input proposal feature map X对应一个best occlusion mask M。所谓best,就是遮挡以至于迷惑detector效果最好的一个mask。这个M是9选一选出来的。如何选的呢?假设X是d*d spatial dimension,那么用一个d/3 * d/3的window在X上有9个实力丁window位置,将这个9个mask apply到X上得到9个masked feature map,run 9次Fast-RCNN的forward pass,选出classification分数最低的哪一个mask作为这个X的最好mask M,构成一个(X,M)对儿。

           将这些(X,M)对儿作为样本,按照如下BCE loss训练ASDN:

易见就是一个加负号的BCE loss。A(X)是ASDN输出的mask,M是GT最优mask,计算这两个mask之间的BCE loss。

     Binarization of ASDN Output Mask

ASDN训练如上得到的输出mask是一个连续值的heat-map,不是我们期待的如GT的binary mask。那么就需要单独对map进行二值化。这个二值化是通过importance sampling实现的,而不是通过硬thresholding,这样做是为了二值化过程的stochasticity和diversity。

           具体做法是选择原mask中值最大的1/2个element,将这些element中随机1/3的值变为1,其余2/3变为0;其余的小的1/2全为1。注意,这个binarization的操作是不可导的!无法BP

    Joint Training of Fast-RCNN & ASDN

如上图,进行一个FP:Fast-RCNN conv层提特征,做ROI pooling,再把每一个pool出来的roi的特征图经过ASDN进行masking,再讲mask之后的occluded feature输入分类器。

           在反传的时候,gradient在dropout即矩阵乘的那一步分流,一部分顺着流回conv layers,另一支无法流入ASDNFC里,因为binarization不可导!这时对ASDN的训练采用了类似REINFORCE的DRL训练方法,具体没细讲。

2. ASTN for deformation(rotation):

基于A. Zisserman大神的STN(全网络可导),本文提出ASTN,其实就是完全的照搬STN的三个components,而adversary得概念体现在训练loss加负号,即原先STN是为了经过deformation让样本更容易被分类loss更小,而现在ASTN的implementation目的是为了把样本变得更难于分类,loss更大。

如上图,ASTN同STN一样由localisation network,grid generator,sampler三个component组成。其中localisation network的输入时roi pooling之后一个roi的feature map,然后经过localisation network的三个FC(前两个是Imagenet pre-trained的Fc6和Fc7)生成三个参数:1. Rotation degree2. Translation Distance3. Scaling factor。在ASTN中,我们只关注对原feature进行rotation带来的deformation。这三个参数输入到Grid Generator和Sampler中得到一个mask,使其与原FM相乘即可。在ASTN中,训练的就是Localisation Network(三个FC的参数)

网络训练

ASTN因为STN全面可导,所以训练可以和Fast-RCNN一起joint train。但是应当也是stage-wise的(作者没详述),即先10K训练Faster-RCNN without ASTN,再fix Fast-RCNN训ASTN(loss加负号) 10K,然后E2E train 10K。

值得注意的有两点(关于如何rotate的限制):

  • 只允许正负10度的rotation degree:太大的rotation导致目标倒立完全不能分类,这个generator造假矫枉过正了;
  • 在roi pooled FM的channel维上4等分,每一等分的channels FM单独rotate:因为FM channel-wise的每一层是一个filter刷出来的结果,刻画着某一种特征。把这些特征按照不同角度rotate,带来的deformation效果更好

 

3. Adversarial Fusion

猜你喜欢

转载自blog.csdn.net/Trasper1/article/details/81541589
今日推荐