结合可变形注意力的视觉Transformer

fig1

前置内容

首先要理解 Q , K , V Q,K,V Q,K,V S o f t m a x ( Q K T ) V Softmax(QK^{T})V Softmax(QKT)V假设window内的patch总数为3,则有:
fig0

可以想象,对于 3 × 3 3\times 3 3×3注意力分布,即 Q K T QK^{T} QKT,行代表query,列代表key, Q K T [ i , j ] QK^{T}[i,j] QKT[i,j]为patch i i i (query)和patch j j j (key)的相似度。对于SwinT中的相对位置偏置信息: S o f t m a x ( Q K T + B ) V Softmax(QK^{T}+B)V Softmax(QKT+B)V其实 Q K T + B QK^{T}+B QKT+B就是在原本的注意力分布基础上引入相对位置设计下的注意力偏置,例如让空间接近的patch关系更密切。

对于 Q K T V QK^{T}V QKTV的第1行向量的第 k k k个元素,相当于将 patch 1 (query)与window内所有patch (keys)的相似度element-wise乘到 V V V的所有patch的第 k k k维值上,然后求和作为patch 1的第 k k k维值。

从上面过程看出,一个query对应着多个key或value,目标是求出当前query的修正后表示。

摘要

Transformer最近在各种视觉任务上表现出了优异的性能。对于巨大的甚至是全局性的感受野赋予Transformer模型比CNN模型更高的表现力。然而,仅仅扩大感受野也会引起一些担忧。一方面,在ViT中使用密集的注意力会导致过度的内存和计算成本,并且特征可能会受到感兴趣区域之外的不相关部分的影响。另一方面,在PVT或Swin Transformer中采用的稀疏注意力可能会限制建立long range关系模型的能力。为了缓解这些问题,我们提出了一种新的可变形的自注意模块,其中以数据相关的方式选择自注意中的键-值对的位置。这种灵活的方案使自注意模块能够专注于相关区域并捕获其特征。在此基础上,我们提出了Deformable Attention Transformer,这是一种用于图像分类和密集预测任务的具有可变形注意力的通用主干模型。大量实验表明,我们的模型在综合基准上取得了持续改进的结果。

1.Introduction

Transformer最初是用来解决自然语言处理任务的。它最近在计算机视觉领域显示出巨大的潜力。ViT的先驱工作是将多个Transformer block堆叠在一起,以处理非重叠的图像patch(即视觉token)序列,从而产生了一种用于图像分类的无卷积模型。与CNN对应模型相比,基于Transformer的模型具有更大的感受野,并且擅长建模长期依赖关系。然而,视觉识别中的过度注意力是一把双刃剑,具有多重缺陷。具体而言,每个query patch要处理过多的keys(query与key计算相似度得到概率分布,修正value),导致计算量大,收敛速度慢,增加了过度拟合的风险。

为了避免过度的注意力计算,现有的研究已经利用精心设计的高效注意力模式来降低计算复杂性。作为其中两种具有代表性的方法,Swin Transformer采用基于窗口的局部注意力来限制计算每个在局部窗口中,而Pyramid Vision Transformer(PVT)对key和value的特征映射进行下采样以节省计算。手工制作的注意力pattern虽然有效,但数据内的对象变换莫测,SwinT可能不是最优的,可能会忽视相关的键/值,而计算了不太相关的键/值。

理想情况下,可以预期给定query的候选key/value集是灵活的,并且能够适应每个单独的输入,从而可以缓解手工构建的稀疏注意模式中的问题。事实上,在CNN的文献中,学习卷积filter的可变形感受野在数据相关的基础上有选择地关注更多信息区域是有效的。最著名的工作是可变形卷积网络,它在许多具有挑战性的视觉任务上取得了令人印象深刻的结果。这促使我们探索视觉Transformer中的可变形注意力pattern。然而,这种想法的朴素实现导致了不合理的高内存/计算复杂性:可变形偏移量引入的开销是patch数量的二次方。因此,尽管最近的一些工作研究了Transformer中可变形机制的概念,但由于计算成本高,没有人将其视为构建像DCN这样强大骨干网络的基本模块。相反,它们的可变形机制要么在检测头中采用,要么用作预处理层,用于为后续骨干网络采样patch。

在本文中,我们提出了一个简单而高效的可变形的自注意模块,利用该模块构造了一个强大的金字塔主干,称为Deformable Attention Transformer(DAT),用于图像分类和各种密集预测任务。与DCN学习整个特征图中不同像素的偏移量不同,我们建议学习几组查询无关(query agnostic)偏移量,以将键和值移动到重要区域(如图1d所示),即注意力在不同的query下得到相同的注意力模式。这种设计既保持了线性空间的复杂性,又将可变形的注意力模式引入了Transformer主干。具体来说,对于每个注意力模块,首先为参考点(reference points)生成统一的网格,这些网格在输入数据中是相同的。然后,偏移网络将query特征作为输入,并为所有参考点生成相应的偏移。通过这种方式,候选key/value被转移到重要区域,从而以更高的灵活性和效率增强了原有的自注意模块,以捕获更多信息的特征。

综上所述,我们的贡献如下:我们提出了第一个用于视觉识别的可变形自注意backbone,其中依赖于数据的注意力模式具有更高的灵活性和效率。在ImageNet、ADE20K和COCO上进行的大量实验表明,我们的模型在图像分类的top1-Acc方面,始终优于包括Swin Transformer在内的竞争性baseline,在语义分割的mIoU方面,我们的模型的优势为+0.7,在box AP和mask AP的目标检测方面,我们模型的优势为+1.1。小物体和大物体的优势更加明显,差值为+2.1。

fig2

  • 图1:比较DAT与SwinT以及CNN模型中的DCN。红星和蓝星表示不同的query patch,带实线边界的掩码表示query所涉及的区域。
  • a:ViT对query给予充分关注。query patch与其他所有patch计算;
  • b:SwinT将与query patch计算的patch限制在局部窗口内;
  • c:DCN为每个query学习不同的变形点;
  • d:DAT学习两个query的共享的变形点
  • a和b是与数据无关的方式,c和d是依赖数据的方式。

2.Related Work

2.1.Transformer backbone

自引入ViT以来,改进的重点是学习密集预测任务的多尺度特征和有效的注意力机制。这些注意力机制包括windowed attention、全局token、focal attention和动态的token size。最近,基于卷积的方法被引入到ViT中。其中,已有研究侧重于用卷积运算补充Transformer模型。CvT在token化过程中采用卷积,并利用空洞卷积来降低自注意的计算复杂度。具有卷积backbone的ViT建议在早期阶段添加卷积,以实现更稳定的训练。CSwin Transformer采用了基于卷积的位置编码技术,并显示了对下游任务的改进。许多基于卷积的技术都有可能应用于DAT之上,以进一步提高性能。

2.2.Deformable CNN and attention

可变形卷积是一种处理以输入数据为条件的灵活变换空间位置的机制。最近,它已应用于ViT。Deformable DETR通过为CNN backbone顶部的每个query选择少量的key,改进了DETR的收敛性。由于缺少key限制了表示能力,其可变形注意力不适合用于作为特征提取的视觉主干。此外,Deformable DETR中的注意力来自简单学习的线性投影,并且query token之间不共享key。DPT和PS ViT构建可变形模块以优化视觉token。具体而言,DPT提出了一种可变形patch embedding方法来细化各个stage的patch,而PS ViT在ViT主干之前引入了一个空间采样模块来改进视觉token。它们都没有将变形注意力纳入视觉中枢。相比之下,我们的可变形注意力采用了一种强大而简单的设计来学习视觉token之间共享的一组全局key,可以作为各种视觉任务的通用主干。我们的方法也可以被视为一种空间自适应机制,这在各种工作中都被证明是有效的。

3.Deformable Attention Transformer

3.1.Preliminaries

首先,我们回顾ViT中的注意力机制。选取一个flattened的特征图 x ∈ R N × C x\in R^{N\times C} xRN×C作为输入, N N N为patch数量, C C C为每个patch展平后的维数,具有 M M M个head的多头注意力(MHSA)block的计算为: q = x W q , k = x W k , v = x W v q=xW_{q},k=xW_{k},v=xW_{v} q=xWq,k=xWk,v=xWv z ( m ) = σ ( q ( m ) k ( m ) T / d ) v ( m ) , m = 1 , . . . , M z^{(m)}=\sigma(q^{(m)}k^{(m)T}/\sqrt{d})v^{(m)},m=1,...,M z(m)=σ(q(m)k(m)T/d )v(m),m=1,...,M z = c o n c a t ( z ( 1 ) , . . . , z ( M ) ) W o z=concat(z^{(1)},...,z^{(M)})W_{o} z=concat(z(1),...,z(M))Wo其中, σ ( ⋅ ) \sigma(\cdot) σ()表示softmax函数, d = C / M d=C/M d=C/M为每个head的维数。 z ( m ) z^{(m)} z(m)为第 m m m个head的输出embedding, q ( m ) , k ( m ) , v ( m ) ∈ R N × d q^{(m)},k^{(m)},v^{(m)}\in R^{N\times d} q(m),k(m),v(m)RN×d表示query,key,value的embedding。 W q , W k , W v , W o ∈ R C × C W_{q},W_{k},W_{v},W_{o}\in R^{C\times C} Wq,Wk,Wv,WoRC×C为投影矩阵。为了建立Transformer block,通常采用具有两个线性变换和GELU激活的MLP block来提供非线性变换。

通过normalization L N LN LN 和残差连接,第 l l l层的Transformer block为: z l ′ = M H S A ( L N ( z l − 1 ) ) + z l − 1 z_{l}'=MHSA(LN(z_{l-1}))+z_{l-1} zl=MHSA(LN(zl1))+zl1 z l = M L P ( L N ( z l ′ ) ) + z l ′ z_{l}=MLP(LN(z_{l}'))+z_{l}' zl=MLP(LN(zl))+zl

3.2.可变形注意力

现有的层次视觉Transformer,尤其是PVT和SwinT试图解决过度注意力(excessive attention)的挑战。前者的下采样技术会导致严重的信息丢失(Transformer在传递特征图的过程中不会下采样,PVT依靠对特征图下采样减少patch数量,并形成多尺度特征),后者的注意力shifted会限制感受野的增长,这限制了建模大型object的潜力。因此,我们需要依赖于数据的稀疏注意力来灵活地建模相关特征,从而产生DCN中首次提出的可变形机制。然而,在Transformer模型中简单地实现DCN是一个非常重要的问题。在DCN中,特征映射上的每个元素分别学习其偏移量,其中 H × W × C H×W×C H×W×C特征映射上的 3 × 3 3×3 3×3可变形卷积具有 2 × 9 × H W 2\times 9\times HW 2×9×HW的空间复杂度(回顾第三十六课.可变形卷积)。

对于可变形注意力,具体而言,我们提出了变形注意,在特征图中重要区域的指导下,有效地建模token之间的关系。这些聚焦区域由多组变形采样点(deformed sampling points)确定,这些变形采样点通过偏移网络(offset network)从query中学习。我们采用双线性插值对特征进行采样,然后将采样的特征反馈给key和value投影,得到变形的key和value。最后,应用标准的多头注意力处理采样后的key的query,并从变形value中聚合特征。此外,变形点的位置提供了更强大的相对位置偏差,以促进可变形注意的学习,这将在以下章节中讨论。

可变形注意力模块

fig3

  • 图2:可变形注意力的演示。
  • a:可变形注意力的流程。在左侧,一组参考点(reference points)统一放置在特征图上,其偏移量由偏移网络从query中学习。然后,根据变形点(Deformed Points,参考点加上偏移)从采样的特征投影得到变形的key和value,我们基于变形点计算相对位置偏差(注意不是前面说的偏移量),以增强输出变换特征的多头注意力。我们只显示了4个参考点(也就是1个query对应4个参考点),以便进行清晰的演示,实际实现中还有更多的参考点。
  • b:显示偏移生成网络的详细结构,并用特征图的大小进行标记。DWconv为Depth-wise Conv,通道可分离卷积减少了卷积的参数。

如图2a所示,输入特征图 x ∈ R H × W × C x\in R^{H\times W\times C} xRH×W×C,会自动生成一组均匀的网格点 p ∈ R H G × W G × 2 p\in R^{H_{G}\times W_{G}\times 2} pRHG×WG×2,具体来说,网格大小为 r × r r\times r r×r,则 H G = H / r , W G = W / r H_{G}=H/r,W_{G}=W/r HG=H/r,WG=W/r参考点的值是线性间隔的二维坐标,其实就是每个网格的左上角,但坐标是相对于网格的 ( 0 , 0 ) , . . . , ( H G − 1 , W G − 1 ) (0,0),...,(H_{G}-1,W_{G}-1) (0,0),...,(HG1,WG1)特别注意,参考点的坐标是相对于网格的,不是相对于特征图patch的。通常网格数量应该小于patch数量,即参考点数量应该小于patch数量,即 H G W G < N H_{G}W_{G}<N HGWG<N

然后我们将参考点坐标进行标准化到 [ − 1 , + 1 ] [-1,+1] [1,+1]之间,其中, ( − 1 , − 1 ) (-1,-1) (1,1)代表左上角网格的左上角坐标 ( + 1 , + 1 ) (+1,+1) (+1,+1)代表右下角网格的左上角坐标。为了获得每个参考点的偏移量,将特征图线性投影到query tokens q = x W q q=xW_{q} q=xWq,然后将其输入到轻量级的网络 θ o f f s e t ( ⋅ ) \theta_{offset}(\cdot) θoffset()生成偏移 Δ p = θ o f f s e t ( q ) ∈ R H G , W G , 2 \Delta p=\theta_{offset}(q)\in R^{H_{G},W_{G},2} Δp=θoffset(q)RHG,WG,2。为了训练的稳定,我们调整 Δ p \Delta p Δp乘以一些预定义的系数 s s s以防止偏移过大,即: Δ p = s ⋅ t a n h ( Δ p ) \Delta p=s\cdot tanh(\Delta p) Δp=stanh(Δp)然后在变形点的区域对特征进行采样,并生成变形后的key和value q = x W q , k ~ = x ~ W k , v ~ = x ~ W v , x ~ = ϕ ( x ; p + Δ p ) q=xW_{q},\tilde{k}=\tilde{x}W_{k},\tilde{v}=\tilde{x}W_{v},\tilde{x}=\phi(x;p+\Delta p) q=xWq,k~=x~Wk,v~=x~Wv,x~=ϕ(x;p+Δp)其中, k ~ , v ~ \tilde{k},\tilde{v} k~,v~表示变形后key和value的embedding。采样函数 ϕ ( ⋅ ; ⋅ ) \phi(\cdot;\cdot) ϕ(;)是可微的双线性插值: ϕ ( z ; ( p x , p y ) ) = ∑ ( r x , r y ) g ( p x , r x ) g ( p y , r y ) z [ r y , r x , : ] \phi(z;(p_{x},p_{y}))=\sum_{(r_{x},r_{y})}g(p_{x},r_{x})g(p_{y},r_{y})z[r_{y},r_{x},:] ϕ(z;(px,py))=(rx,ry)g(px,rx)g(py,ry)z[ry,rx,:]其中, g ( a , b ) = m a x ( 0 , 1 − ∣ a − b ∣ ) g(a,b)=max(0,1-|a-b|) g(a,b)=max(0,1ab) ( r x , r y ) (r_{x},r_{y}) (rx,ry)为特征图 z z z的位置索引(在实数空间 z ∈ R H , W , C z\in R^{H,W,C} zRH,W,C下)。由于 g g g仅在最接近 ( p x , p y ) (p_{x},p_{y}) (px,py)的4个坐标上为非零(注意:在这个函数中 ( p x , p y ) (p_{x},p_{y}) (px,py)已经反归一化为整数,由于 g ( a , b ) = m a x ( 0 , 1 − ∣ a − b ∣ ) g(a,b)=max(0,1-|a-b|) g(a,b)=max(0,1ab),有效 ( r x , r y ) (r_{x},r_{y}) (rx,ry)距离 ( p x , p y ) (p_{x},p_{y}) (px,py)最大为 1 1 1),所以它将上式简化为这4个距离变形点最近位置的网格的特征向量加权平均值见图2a的四个方形区域)。注意 p p p r r r反归一化后为整数,这才能确保索引到特征图的向量。

回顾之前说的,每个query patch对应 H G W G H_{G}W_{G} HGWG个参考点,在对参考点进行偏移时,我们只用卷积网络生成了一个偏移量的张量,相当于是不同query的变形点都直接共享了,(见图1d)

x ~ ∈ R H G W G × C \tilde{x}\in R^{H_{G}W_{G}\times C} x~RHGWG×C的形状与 x ∈ R N × C x\in R^{N\times C} xRN×C不一样, x ~ \tilde{x} x~中的patch i i i是某个query patch(query patch共 N N N个)的参考点 i i i(参考点共 H G W G H_{G}W_{G} HGWG个)经过变形后的patch embedding。

后面的操作与现有方法类似,我们在 q , k , v q,k,v q,k,v上执行多头注意力,并采用相对位置偏移量 R R R,注意力头的输出公式如下: z ( m ) = σ ( q ( m ) k ~ ( m ) T / d + ϕ ( B ^ ; R ) ) v ~ ( m ) z^{(m)}=\sigma(q^{(m)}\tilde{k}^{(m)T}/\sqrt{d}+\phi(\widehat{B};R))\tilde{v}^{(m)} z(m)=σ(q(m)k~(m)T/d +ϕ(B ;R))v~(m)其中, ϕ ( B ^ ; R ) ∈ R H W × H G W G \phi(\widehat{B};R)\in R^{HW\times H_{G}W_{G}} ϕ(B ;R)RHW×HGWG对应位置embedding。详细信息将在本节后面进行解释。我们将每个head的特征concat在一起,并通过 W o W_{o} Wo投影,以获得最终输出 z z z。回顾前置内容,Attention的目的是为每个query重新计算其表示,所以 z ( m ) ∈ R H W , − 1 z^{(m)}\in R^{HW,-1} z(m)RHW,1

个人理解:可变形注意力的意义

比如在重新计算某个Query比如 q u e r y x query_x queryx的表示 q u e r y x ′ query_x' queryx的过程中,我们把注意力计算限制在使用 q u e r y x query_x queryx乘变形点对应的key上,得到注意力分布,然后乘变形点对应的value,得到 q u e r y x ′ query_x' queryx。也就是说, q u e r y x ′ query_x' queryx是融合可变形感受野的embedding得到的。
fig001


偏移的计算
可以看出,在计算偏移时,其实DAT和DCN的做法是一样的,还是依赖卷积;

SwinTransformer中,关于Attention中的相对位置偏差
引入位置的形式为: S o f t m a x ( Q K T + B ) V Softmax(QK^{T}+B)V Softmax(QKT+B)V其中, Q , K , V ∈ R M 2 × d Q,K,V\in R^{M^{2}\times d} Q,K,VRM2×d B ∈ R M 2 × M 2 B\in R^{M^{2}\times M^{2}} BRM2×M2为相对位置偏差。

相对位置偏差对每对query和key之间的相对位置进行编码,利用空间信息增强注意力。在SwinT中,假设window size M = 7 M=7 M=7,则共有49个patch,所以有 4 9 2 49^{2} 492个相对位置,每个相对位置有两个索引(对应 x x x y y y方向),每个索引的取值范围是 [ − 6 , 6 ] [-6,6] [6,6]

  • 第0行相对第6行, x = 0 − 6 = − 6 x=0-6=-6 x=06=6,第6行相对第0行, x = 6 − 0 = 6 x=6-0=6 x=60=6
  • 第0列相对第6列, y = 0 − 6 = − 6 y=0-6=-6 y=06=6,第6列相对第0列, y = 6 − 0 = 6 y=6-0=6 y=60=6

此时,构建的相对位置张量relative_coord的shape为 ( 49 , 49 , 2 ) (49,49,2) (49,49,2)relative_coord[i, j, :]表示window内第 i i i个patch相对第 j j j个patch的坐标。

由于此时索引取值范围中包含负值,可分别在每个方向上加上6,使得索引取值从0开始。此时,索引取值范围为 [ 0 , 12 ] [0,12] [0,12]

这个时候可以参数化一个shape为 [ 13 , 13 ] [13,13] [13,13]Table B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \widehat{B}\in R^{(2M-1)\times (2M-1)} B R(2M1)×(2M1),则当相对位置为 ( i , j ) (i,j) (i,j)时,有 b = B ^ [ i , j ] b=\widehat{B}[i, j] b=B [i,j]。其中, i , j i,j i,j的取值范围是 [ 0 , 12 ] [0,12] [0,12] B B B中的值 b b b来自 B ^ \widehat{B} B

  • 举个例子:window内,49个patch,要计算第1个patch相对第6个patch的位置偏差,即 B [ 0 , 5 ] B[0,5] B[0,5] B ∈ R 49 × 49 B\in R^{49\times 49} BR49×49保存了patch i i i相对patch j j j的位置偏移注意力,首先获取relative_coord[0,5,:]得到元组(x,y)(即第1个patch相对第6个patch的相对位置),然后查表有 B [ 0 , 5 ] = B ^ [ x , y ] B[0,5]=\widehat{B}[x,y] B[0,5]=B [x,y]

由于使用的是 multi head self-attention,所以Table[i, j]的值应该是一个维度为num_heads的一维向量。


Offset generation

如前所述,offset的生成采用一个轻量级子网络,该子网络使用query特征并输出参考点的offset。考虑到每个参考点覆盖一个局部 s × s s×s s×s区域( s s s是偏移量的最大值),生成网络应具有局部特征感知能力,以得到合理的偏移量。因此,我们将子网络实现为具有非线性激活的两个卷积模块,如图2b所示。输入特征首先通过 5 × 5 5×5 5×5深度可分离卷积来捕获局部特征。然后,采用GELU激活和 1 × 1 1×1 1×1卷积得到二维偏移量。

Offset groups

为了促进变形点的多样性,我们遵循MHSA的思想,将特征通道划分为 G G G组。每组特征使用共享子网络分别生成相应的偏移。实际上,注意力模块的头部编号 M M M被设置为offset groups G G G 大小的倍数,确保将多个注意力头分配给一组变形的key和value。Offset groups用于生成多组偏移量。

Deformable relative position bias

相对位置偏差对每对query和key之间的相对位置进行编码,利用空间信息增强注意力。考虑形状为 H × W H\times W H×W的特征图,其相对坐标偏差分别在 [ − H , H ] [−H,H] [H,H] [ − W , W ] [−W,W] [W,W]上。在Swin Transformer中,构造相对位置偏差表 B ^ ∈ R ( 2 H − 1 ) × ( 2 W − 1 ) \widehat{B}\in R^{(2H−1)×(2W−1)} B R(2H1)×(2W1)以获得相对位置偏差 B B B参考上面内容:SwinT中,关于Attention中的相对位置偏差)。由于我们的可变形注意力具有连续型的变形点位置信息,我们计算值在 [ − 1 , + 1 ] [-1,+1] [1,+1]内的相对位置 R R R(即relative_coord,注意是patch对应变形点之间的相对位置),然后在参数化的相对位置偏差表 B ^ ∈ R ( 2 H − 1 ) × ( 2 W − 1 ) \widehat{B}\in R^{(2H-1)\times (2W-1)} B R(2H1)×(2W1)中插值 ϕ ( B ^ ; R ) ∈ R H W × H G W G \phi(\widehat{B};R)\in R^{HW\times H_{G}W_{G}} ϕ(B ;R)RHW×HGWG得到相对位置偏差。

3.3.模型架构

fig4

  • 图3:模型架构, N 1 N_1 N1 N 4 N_4 N4是堆叠的连续局部注意力和移位窗口/可变形注意块的数量。 k k k s s s表示patch embedding中卷积层的kernel size和步长。

如图3所示,形状为 H × W × 3 H×W×3 H×W×3的输入图像首先通过步长为4的 4 × 4 4×4 4×4非重叠卷积得到embedding,然后通过归一化层得到 ( H / 4 ) × ( W / 4 ) × C (H/4)×(W/4)×C (H/4)×(W/4)×C patch embedding。主干包括4个stage,步长逐渐增大,旨在构建层次特征金字塔。在两个连续的stage之间,有一个不重叠的 2 × 2 2×2 2×2卷积,步幅为2,以对特征图进行下采样,从而将空间大小减半,并将特征尺寸加倍。在分类任务中,我们首先对最后阶段输出的特征图进行规范化,然后采用具有混合特征的线性分类器来预测logit。在目标检测、实例分割和语义分割任务中,DAT在综合视觉模型中起着骨干作用,用于提取多尺度特征。我们为每个阶段的特征添加了一个规范化层,然后将其输入以下模块,如目标检测中的FPN或语义分割中的解码器

在DAT的第三和第四stage,我们引入了连续的局部注意力(只在window内编码)和可变形注意力块。首先通过基于窗口的局部注意力对特征图进行处理,以局部聚集信息,然后通过可变形注意块来建模局部增强token之间的全局关系。这种带有局部和全局感受野的注意块的交替设计有助于模型学习embedding。

由于前两个stage主要学习局部特征(初期阶段),因此这些早期阶段的可变形注意力不太适用。此外,前两个阶段中的key和value具有相当大的空间大小,这大大增加了变形注意力中双线性插值的计算开销。因此,为了在模型容量和计算负担之间达成平衡,我们只在第三和第四阶段放置可变形注意力(在较高级语义上进行可变形注意力计算),并采用在Swin Transformer中的移位窗口注意力,以便在早期阶段有更好的表示。我们构建了三种DAT变体,以便与其他视觉模型进行比较。详细的体系结构如表1所示。
fig5

  • 表1:模型架构。 N i N_i Ni:stage i i i的块数。 C C C:通道数。window size:局部注意模块中的区域大小。heads:DMHA中的head数。groups:DMHA中的Offset group数。

4.Experiments

实验结果

我们在几个数据集上进行了实验,以验证我们提出的DAT的有效性。我们展示了我们在ImageNet-1K分类、COCO目标检测和ADE20K语义分割任务上的结果。
fig6

  • 表2:在ImageNet-1K分类任务中,DAT与其他视觉主干在速度,参数量和精度方面的比较。
    fig7
  • 表3:基于RetinaNet的COCO目标检测结果。RGB输入图像分辨率为1280×800。

fig8

  • 表4:COCO目标检测和实例分割的结果。RGB输入图像分辨率为1280×800。

fig9

  • 表5:语义分割结果。RGB输入图像的分辨率为512×2048。

可视化结果

为了验证可变形注意力的有效性,我们使用了与DCNs类似的机制,通过传播他们的注意力权重实现可视化。如图5所示,我们的可变形注意力将重要的key放在前景中,这表明它关注对象的重要区域,这支持了图1所示的假设。更多可视化内容见图6和7。
fig10

  • 图5:COCO验证集上重要的key可视化。橙色圆圈显示了在多头传播注意力中得分最高的key。半径越大表示得分越高。请注意,右下角的图像显示的是一个人挥动球拍打网球。

fig11

  • 图6:在DAT的第3阶段(第一行)和第4阶段(第二行),在COCO上可视化可变形注意力中的变形点位置。

fig12

  • 图7:COCO验证集上的可视化。red star表示一个query,橙色点是最后一层注意力得分较高的key。第一行和第三行中的图像描述了我们的DAT注意力,第二行和第四行中显示了SwinTransformer的注意力。

个人总结

  • 比如在重新计算某个Query比如 q u e r y x query_x queryx的表示 q u e r y x ′ query_x' queryx的过程中,我们把注意力计算限制在使用 q u e r y x query_x queryx乘变形点对应的key上,得到注意力分布,然后乘变形点对应的value,得到 q u e r y x ′ query_x' queryx。也就是说, q u e r y x ′ query_x' queryx是融合可变形感受野的embedding得到的。
    fig002
  • 在计算偏移时,其实DAT和DCN的做法是一样的,还是依赖卷积,但是做出了解释,偏移应该是局部的,局部的偏移量更多情况下才是合理的偏移量(一张图像的对象通常不会占据整个图像,而且对象数量通常不止一个),所以使用卷积,快速且符合实际场景。
  • DCN是局部偏移,但是有可能会感知到错误的上下文,或许是因为卷积导致的局部信息感知,不能看到完整的上下文,这是DCN学习上就存在的瓶颈。
  • DAT做了先验假设,即reference points,并且参考点被query共享,对每个query来说,都只与变形点(变形点由参考点偏移得到)计算注意力分布并更新query的表达。有效之处在于:1.参考点是全局的,2.变形点是用偏移明确修正到目标对象上的全局信息,这样计算后的query patch表达与图像中的目标对象更加密切相关,一定程度上削弱了背景噪声。第一点引入了全局信息,但这是Transformer本身存在的优势,主要在于第二点,实现了对象形状自适应的注意力限制,去除噪声,相当于在感兴趣区域上计算注意力。

猜你喜欢

转载自blog.csdn.net/qq_40943760/article/details/125091334