【DETR】3、Conditional DETR | 拆分 content 和 spatial 来实现对 DETR 的加速

在这里插入图片描述

论文:Conditional DETR for Fast Training Convergence

代码:https://github.com/Atten4Vis/ConditionalDETR

出处:ICCV2021

一、Conditional DETR 是怎么被提出来的

DETR 方法自被提出以来,就以其端到端的结构得到了很大的关注,DETR 方法首次将 Transformer 结构引入了目标检测任务中,且使用二分匹配的方式让一个目标只有一个输出框,免去了 NMS 的操作,使用 object queries 的方式免去了手工设置 anchor 的过程。

但是 DETR 的缺点也很明显:

  • 收敛速度慢:需要 500 个 epoch 才能收敛(是 Faster RCNN 的 10~20 倍)
  • 对小目标效果不好:DETR 使用的是 CNN 抽取的最后一层特征(1/32),分辨率较低,细节信息缺失严重,所以小目标检出效果较差

DETR 的基本流程:

  • CNN backbone 提取图像的 feature
  • Transformer Encoder 通过 self-attention 建模全局关系对 feature 进行增强
  • Transformer Decoder 的输入是 object queries 和 Transformer encoder 的输出,主要包含 self-attention 和 cross-attention 的过程。Cross- attention 主要是将 object query 当做查询,encoder feature 当做 key,为了查询和 query 有关的区域。Self-attention 主要是对每个 query 做交互,让每个 query 能看到其他 query 在查询什么东西,从而不重复,类似与 NMS 的作用
  • 对 Decoder 输出的查询好了的 query,使用 FFN 提取出 bbox 的位置和类别信息

针对 DETR 的收敛慢的问题,Conditional DETR 对 DETR 的结构进行了进一步的探索,从 DETR 的 decoder 入手,回顾一下 DETR 中提到的 decoder 的主要作用是关注目标的边界位置(extremity),如下图所示,包括边界位置的大象鼻子、大象蹄子、斑马蹄子等,都是对边界定位很重要的边界信息。

在这里插入图片描述

然后我们再看看 DETR 不同训练 epoch 的模型对边界位置的定位能力,如图 1 所示,在 DETR 训练 50 个 epoch 时(第二行),从空间 attention 图中可以明显的看出来模型并不能很好的找到目标的边界,在训练到 500 个 epoch 的时候(第三行),模型就能很好的找到目标的边界位置的。这里的每个列表示不同 head 的输出。

在这里插入图片描述

为什么会出现这样的问题呢,作者就进行了 DETR 结构的重新思考,并且提出了对之后的多篇工作影响很大的 content query/key 和 spatial query/key 的概念

在 DETR 中,DETR decoder 的 cross-attention 如下图所示(图来自 DETR 原文,cross-attention 是右边分支中间位置的模块),是对 encoder 的输出和 object query 做交叉注意力的模块,会有三个输入:

  • queries:每个 query 都是 decoder 第一层 self-attention 的输出( content query )+ object query( spatial query ),这里的 object query 就是 DETR 中提出的概念,每个 object query 都是候选框的信息,经过 FFN 后能输出位置和类别信息(本文 object query 个数 N 为 300)
  • keys:每个 key 都是 encoder 的输出特征( content key ) + 位置编码( spatial key )构成的,可以看 k 的下边是有一个 + 号的
  • values:只有来自 encoder 的输出

在这里插入图片描述

基于上面提到的 content 和 spatial 的概念,我们就可以继续往下看了,作者本文的动机来源于这么一个结论:

  • 在 cross-attention 中,content embedding 贡献很大,spatial embedding 贡献很小
  • 假设移除 key 中的 positional embedding 和第二层的 object query(但是第一层的 object query 是一定不能移除的哦!!!),AP 只会下降很小的值

影响究竟有多小呢:

  • 训练 50 个 epoch 的时候,不使用 spatial embedding 的话,AP 从 34.9 降到了 34.0
  • 训练 300 个 epoch 的时候,不使用 spatial embedding 的话,AP 下降 1.4

所以 DETR 是把 content key/query 和 spatial key/query 合在一起训练的,既然 spatial 影响不大,那就说明 content 的影响很大,需要要求 content 质量很高,才能得到较好的效果,当 content 质量不高的时候,spatial 再挣扎也没用。所以说,DETR 收敛慢的原因很大程度上是来源于 content 的质量太低了,且难以提升,所以 DETR 需要更大的训练 epoch 才能得到较好的效果,也进一步说明边界位置的定位能力其实来源于 content 特征,要想更快的提升模型对边界位置的定位能力,要能更快的让 content 特征优化得当才行。

说到这里,Conditional DETR 就提出了自己的方法,主要解决的就是 decoder 中的 cross-attention 的 content 难以学习的问题,Conditional DETR 从前一个 decoder 的输出为每个 query 都学习一个 conditional spatial embedding,来为 decoder 的 cross-attention 形成一个所谓的 【 conditional spatial query】。

二、Conditional DETR 的具体实现

2.1 框架结构

Conditional DETR 的方法沿用了 DETR 的整体流程,包括 CNN backbone,transformer encoder,transformer decoder, 以及 object class 和 box 位置的预测器。Encoder 和 Decoder 各自由6个相同的 layer 堆叠而成。Conditional DETR 相对于 DETR 的改动主要在 decoder 的 cross-attention 部分。

改动后的 decoder 的一层如图 3 所示,有三层:

  • 第一层: 输入为上一层 decoder 的输出,self-attention 过程用于去除冗余的预测,self-attention 就是来求 object query embedding 之间的关系,在每个特征都互相交流之后,就能知道每个 query 在预测什么样的特征,从而避免预测的特征冗余了,尽量都预测不一样的特征才好。
  • 第二层:cross-attention 层,用于聚合 encoder 的输出和 decoder 内部的特征,进一步加强特征
  • 第三层:FFN 层

在这里插入图片描述

Box 位置的回归:

在这里插入图片描述

  • f 是 decoder embedding
  • b 是 4d 向量 [x, y, w, h]
  • sigmoid 用于将输出归一化为 [0,1]
  • s 是参考点的未经归一化的 2D 坐标,在 DETR 中是 (0,0)

类别 预测:

在这里插入图片描述

conditional DETR 的 cross-attention 主要的作用:

  • 定位目标区域,也就是定位 bbox 的 4 个边界位置来定位,识别框内的目标区域用于分类
  • 通过引入 conditional spatial queries ,能够提升定位能力,且加速训练过程

2.2 DETR 的 cross-attention 和 Conditional DETR 的 cross-attention 对比

1、DETR

在 DETR 中,DETR decoder 的 cross-attention 如下图所示(图来自 DETR 原文,cross-attention 是右边分支中间位置的模块),是对 encoder 的输出和 object query 做交叉注意力的模块,会有三个输入:

  • queries:每个 query 都是 decoder 第一层 self-attention 的输出( content query c q c_q cq)+ object query( spatial query p q p_q pq ),这里的 object query 就是 DETR 中提出的概念,每个 object query 都是候选框的信息,经过 FFN 后能输出位置和类别信息(本文 object query o q o_q oq个数 N 为 300)
  • keys:每个 key 都是 encoder 的输出特征( content key c k c_k ck ) + 位置编码( spatial key p k p_k pk)构成的,可以看 k 的下边是有一个 + 号的
  • values:只有来自 encoder 的输出

因为所有的 spatial 的 content query/key 都是加在一起的,所以 attention 时候操作过程如下:

在这里插入图片描述

2、Conditional DETR

conditional cross-attention 方法将 content 和 spatial 分开了,使用如下的方式进行,让 spatial query 聚焦于 spatial attention,content query 聚焦于 content attention

在这里插入图片描述

三、效果

在这里插入图片描述

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/jiaoyangwm/article/details/132431252