Multiscale Vision Transformers 论文详解

目录

Abstract

1. Introduction

2. Related Work 

3. Multiscale Vision Transformer (MViT) 

3.1. Multi Head Pooling Attention 

3.2. Multiscale Transformer Networks

3.3. Network instantiation details


论文地址:https://arxiv.org/abs/2104.11227

代码地址:https://github.com/facebookresearch/SlowFast

个人感觉有两个贡献点:

  • 利用 pooling 操作实现下采样,pooling 本身就是那种模仿视觉的算子,用的比较巧妙,让我引发了一些关于多尺度的思考,还节省了参数量
  • 不同的 stage 使用不同核大小的 pooling ,这些 stage 连起来就像多尺度金字塔(作者那么说的)。

Abstract

        我们提出了 Multiscale Vision Transformers(MViT)用于视频和图像识别。MViT 是多尺度特征层次结构和Transformer的结合。MViT 有几个通道分辨率尺度块(channel-resoluation scale stages)。从输入分辨率和小通道维度开始,这些stages扩展通道容量,同时降低空间分辨率。这创建了一个多尺度特征金字塔,早些的层在高空间分辨率下运行以模拟简单的低级视觉信息,而更深层在空间粗糙但复杂的高维特征上运行。我们评估了这种MViT,用于各种视频识别任务中密集型任务,它优于依赖大规模外部预训练并且在计算和参数比我们高出 5-10 倍的 ViT。我们删除了时间维度并将我们的模型应用于图像分类,它也优于先前 ViT 的表现。

简单概括:MViT引入了多尺度特征金字塔结构,解决了视频识别任务中目标密集型任务。而且参数量与推理速度要比ViT少很多,在图片分类任务上的表现也比ViT的要好。

1. Introduction

        我们从计算机视觉神经网络模型的知识史开始。基于对猫和猴子视觉皮层的研究,Hubel 和 Wiesel [55] 发明了视觉通路的层次模型,其中神经元在较低区域(例如 V1)对定向边缘和条形等特征做出反应,而在较高区域则对更具体刺激做出反应。Fukushima 提出了 Neocognitron [32],这是一种由 Hubel 和 Wiesel 的层次结构明确驱动的模式识别神经网络架构。他的模型具有简单单元和复杂单元的交替层,因此包含了下采样,平移不变性和卷积结构。 LeCun 等人。 [65] 采取了使用反向传播来训练该网络的权重。但是已经建立了视觉处理层次结构的主要方面:(i)随着处理层次的增加,空间分辨率降低(ii)增加了不同的通道数量,这些通道对应更多特殊的特征。

        在并行开发中,计算机视觉社区开发了多尺度处理,有时称为“金字塔”策略,Rosenfeld 和 Thurston [85]、Burt 和 Adelson [8]、Koenderink [61] 等都是关键论文。有两个动机(i)通过在较低分辨率下工作来减少计算需求,以及(ii)在较低分辨率下更好地理解“上下文”,然后可以指导高分辨率下的处理。

        Transformer [98] 架构允许学习在集合上定义的任意函数,并且在语言理解 [26] 和机器翻译 [7] 等序列任务中取得了可扩展的成功。从根本上说,transformer 使用具有两个基本操作的块。首先,是用于建模元素间关系的注意力操作 [4]。其次,是多层感知器 (MLP),它对元素内的关系进行建模。将这些操作与归一化 [2] 和残差连接 [44] 交织在一起,可以让transformers泛化到各种各样的任务。

        最近,transformer 已应用于关键的计算机视觉任务,例如图像分类。本着架构普适主义的精神,vision transformer [25, 95] 在各种数据上或计算状态上都媲美卷积模型。通过只有第一层使用 2D 卷积调整输入,然后是一堆 transformer blocks,vit 旨在展示 transformer 架构的强大功能,使用很少的归纳偏置。

        在本文中,我们的目的是将多尺度特征层次结构与 transformer 模型联系起来。我们假设分辨率和通道调整的基本视觉原理对于跨各种视觉识别任务的 transformer 模型可能是有益的。​​​​​​​

        我们提出MViT,一种用于对图像和视频等视觉数据进行建模的 transformer 架构。考虑如图 1 所示的输入图像。与在整个网络中保持恒定通道容量和分辨率的传统 transformers 不同,MViT具有多个通道分辨率“尺度”阶段。这在 transformer 网络内部创建了一个多尺度的特征激活金字塔,有效地将 transformers 的原理与多尺度特征层次结构联系起来。

        我们的概念性想法为 ViT 模型提供了有效的设计优势。由于轻量级通道容量,我们架构的前面的层可以在高空间分辨率下运行以模拟简单的低级视觉信息。反过来,更深层可以有效地关注空间粗糙但复杂的高级特征来建模视觉语义。我们的 MViT 的可以很好的处理极其稠密的视觉信号,这种现象对于视频中捕获的时空视觉信号更为明显。

        我们设计的一个值得注意的好处是视频多尺度模型中存在强烈的隐式时间偏差(implicit temporal bias)。我们表明,在自然视频上训练的 ViT  [25] 在使用随机帧的视频上进行测试时不会出现性能衰减。这表明这些模型没有有效地使用时间信息,而是严重依赖图像特征。相比之下,当在打乱的帧上测试我们的 MViT 模型时,我们观察到显着的精度衰减,表明对时间信息的大量使用。

        我们在本文中的重点是视频识别,我们为视频任务设计和评估 MViT(Kinetics [59,10]、Charades [86]、SSv2 [38] 和 AVA [39])。与同时期的 ViTs [78,6,1] 相比,MViT 提供了显着的性能增益,无需任何外部预训练数据。

        在图 A.4 中,我们展示了当改变 MViT 中使用的时间片段数量时,视频推理的计算/准确性权衡。纵轴显示 Kinetics-400 的准确性,横轴显示不同模型、MViT 和并发 ViT [25] 视频变体的 FLOPs 的总体推理成本:VTN [78]、TimeSformer [6]、ViViT [1]。为了达到与 MViT 相似的精度水平,这些模型需要更多的计算和参数(例如,ViViT-L [1] 具有 6.8 倍更高的 FLOPs 和 8.5 倍更多的参数,在相同的精度下,§A.1 中有更多分析)并且需要大 -在 ImageNet-21K(包含比 Kinetics-400 多 60 倍的标签)上扩展外部预训练。

        我们进一步将我们的架构应用于 ImageNet [21] 上的图像分类任务,通过简单地删除通时间维度,并显示出比用于图像识别的单尺度 ViTs 的效果更好。

2. Related Work 

        卷积网络(ConvNets)。结合下采样、移位不变性和共享权重,ConvNets 是计算​​机视觉任务的事实上的标准主干。

        卷积网络中的自注意力。自注意力机制已用于图像理解、无监督对象识别以及视觉和语言。自注意力操作和卷积网络的混合也已应用于图像理解和视频识别。

        Vision Transformers。当前将 Transformers 应用于视觉任务的大部分热情始于 Vision Transformer (ViT) 和 Detection Transformer 。我们直接在ViT 的基础上构建了一个允许通道扩展和分辨率下采样的分阶段模型。 DeiT 提出了一种训练 ViT 的数据有效方法。我们的训练方法建立在相同设置下的 DeiT 之上,并将我们的图像分类模型与 DeiT 进行比较。

        一个新兴的工作想要将 transformers 应用于视觉任务,例如对象检测、语义分割、3D 重建、姿态估计、生成建模、图像检索、医学图像分割、点云、视频实例分割、对象重识别、视频检索、视频对话、视频对象检测和多模式任务 。一个单独的工作系列尝试使用学习的离散标记序列对视觉数据进行建模。

        Efficient Transformers。最近的工作降低了二次注意力的复杂性,使 transformers 对于自然语言处理应用程序更加高效,这是对我们方法的补充。

        三项并行工作提出了一种基于 ViT 的视频架构。然而,这些方法依赖于对大量外部数据的预训练,例如 ImageNet21K [21],因此使用 vanilla ViT [25] 进行最少的调整。相比之下,我们的 MViT 为转换器引入了多尺度特征层次结构,允许在没有大规模外部数据的情况下对密集视觉输入进行有效建模。

3. Multiscale Vision Transformer (MViT) 

        我们的通用 Multiscale Vision Transformer 架构建立在 stages 这个核心概念之上。每个 stage 都包含多个具有特定时空分辨率和通道维度的 transformer block。 Multiscale Transformers 的主要思想是逐步扩展通道容量,同时汇集网络从输入到输出的分辨率。

3.1. Multi Head Pooling Attention 

        我们首先描述多头池注意 (MHPA),这是一种自注意运算符,可在转换器块中实现灵活的分辨率建模,从而允许 MViT 在逐渐变化的时空分辨率下运行。与原始的多头注意 (MHA) 运算符相比,其中通道维度和时空分辨率保持固定,MHPA 池化潜在张量序列以减少参与输入的序列长度(分辨率)。图 3 显示了这个概念。 

         具体来说,考虑序列长度为 L 的 D 维输入张量 X,X \in \mathbb{R}^{L \times D}。根据 MHA,MHPA将输入 X 映射到中间 query  \hat{Q} \in \mathbb{R}^{L \times D},key \hat{K} \in \mathbb{R}^{L \times D} 和 value \hat{V} \in \mathbb{R}^{L \times D}

\hat{Q}=XW_{Q},\ \hat{K}=XW_{K},\ \hat{V}=XW_{V}\

其中 W_{Q},W_{K},W_{V}s 是D\times D 维。然后将获得的中间张量按序列长度合并,使用 pooling 运算符 p,如下所述。

        Pooling 运算符。在计算atten之前,中间张量 W_{Q},W_{K},W_{V} 使用运算符 P(\cdot ; \Theta) ,这是我们的 MHPA 的基石,通过扩展,我们的多尺度 Transformer 架构。

        运算符 P(\cdot ; \Theta) 沿每个维度对输入张量执行 pooling 计算。\Theta :=(k,s,p), 运算符使用维度为k^{T} \times k^{H} \times k^{W} 的池核 k,对应维度为 s^{T} \times s^{H} \times s^{W} 的步长 s 和对应维度为 p^{T} \times p^{H} \times p^{W} 的填充 p,通过与坐标应用方程式降低输入张量的维度 L = T\times H \times W 到 \widetilde{L}

\widetilde{L}=\left \lfloor \frac{L+2p-k}{s} \right \rfloor+1

 合并的张量再次展平,产生 P(Y ; \Theta ) \in \mathbb{R}^{\widetilde{L}\times D} 的输出,\widetilde{L}=\widetilde{T}\times \widetilde{H}\times \widetilde{W}

        默认情况下,我们在池化注意力运算符中使用具有保形填充 p 的重叠内核 k,因此输出张量 P(Y ; \Theta ) 的序列长度 \widetilde{L} 整体减少由 s^{T} \times s^{H} \times s^{W} 因素决定。

         Pooling Attention. 池化运算符 P(Y ; \Theta ) 独立地应用于所有中间张量 \hat{Q}\hat{K}\hat{V},并选择池化内核 k、步长 s 和填充 p。表示 θ 产生预注意向量 Q = P( \hat{Q}; \Theta_{Q})​​​​​​​, K = P( \hat{K}; \Theta_{K}) 和 V = P( \hat{V}; \Theta_{V}) 减少序列长度。现在在这些缩短的向量上计算注意力,通过操作

​​​​​​​Attention(Q,K,V)=Softmax(QK^{T}/\sqrt{D})V

 自然地,该操作会在池化算子上引入约束 s_{K}\equiv s_{V}。总之,集中注意力计算为,

Attention(Q,K,V)=Softmax( P(Q; \Theta_{Q}) P(K; \Theta_{K})^{T} / \sqrt{d}) P(V ;\Theta_{V})

 其中 \sqrt{d} 是按行对内积矩阵进行归一化。因此,随着 P(\cdot ) 中query Q 的缩短,Pooling attention 操作的输出的序列长度减少了 s_{T}^{Q} \times s_{H}^{Q} \times s_{W}^{Q} 的步幅因子。

        Multiple heads. 与 [98] 中一样,可以通过考虑 h 个头来并行化计算,其中每个头都在 D 维输入张量 X 的 D/h 通道的非重叠子集上执行集中注意力。

        Computational Analysis​​​​​​​. 由于注意力计算以二次方式缩放序列长度、pooling key、query 和 value 张量对多尺度 Transformer 模型的基本计算和内存要求具有显着的好处。由于序列长度缩减因子 f_{Q}f_{K}f_{V},我们有,

​​​​​​​f_{j}= s_{T}^{j}\cdot s_{H}^{j}\cdot s_{W}^{j}, \forall j\in \left \{Q,K,V \right \}

考虑到 P(\cdot ; \Theta) 的输入张量具有维度 D \times T \times H \times W,MHPA 的每个头运行时复杂度为 O(THWD/h(D+THW/f_{Q}f_{K})) ​​​​​​​和内存复杂度是 O(THWh(D/h +THW/f_{Q} f_{K} ))

         通道数 D 和序列长度项 THW/f_{Q} f_{K} 之间的这种权衡告知我们关于结构参数的设计选择,例如头数和层宽。我们建议读者参阅补充资料,了解有关时间记忆复杂性权衡的详细分析和讨论。

3.2. Multiscale Transformer Networks

        基于 Multi Head Pooling Attention(第 3.1 节),我们描述了专门使用 MHPA 和 MLP 层进行视觉表示学习的多尺度 Transformer 模型。首先,我们简要回顾了为我们的设计提供信息的 Vision Transformer 模型。

        Preliminaries: Vision Transformer (ViT). Vision Transformer (ViT) 架构 [25] 首先将分辨率为 T×H×W 的输入视频切成大小为 1×16×16 的非重叠块,其中 T 是帧数,H 是高度,W 是宽度每个 16 个,然后在扁平图像块上逐点应用线性层,以将它们投射到变换器的潜在维度 D 中。这相当于具有相同内核大小和 1×16×16 步长的卷积,在表 1 的模型定义中显示为 patch 1 阶段。

        接下来,将位置嵌入 E \in \mathbb{R}^{L\times D} 添加到每个长度为 L 维度为 D 的元素来编码位置信息并打破置换不变性。可学习的类嵌入被附加到投影的图像块中。

        得到的长度为 L + 1 的序列随后由一堆 N 个变换器块按顺序处理,每个变换器块执行注意力 (MHA [98])、多层感知器 (MLP) 和层归一化 (LN [3]) 操作。将 X 视为块的输入,单个变换器块的输出 Block(X) 由下式计算

X_{1}=MHA(LN(X))+X

Block(X)=MLP(LN(X_{1}))+X_{1}

N 个连续块后的结果序列被层归一化,类别嵌入被提取并通过线性层以预测所需的输出(例如类)。默认情况下,MLP 的隐藏维度是 4D。我们建议读者参考 [25,98] 了解详细信息。

        值得注意的是 ViT 在所有块中保持恒定的信道容量和空间分辨率(见表 1)。

Multiscale Vision Transformers (MViT). 我们的关键概念是逐步增加通道分辨率(即维度),同时降低整个网络的时空分辨率(即序列长度)。通过设计,我们的 MViT 架构在早期层具有精细的时空(和粗通道)分辨率,在后期层中上/下采样到粗时空(和精细通道)分辨率。 MViT 如表 2 所示。

        Scale stages. 一个尺度阶段被定义为一组 N 个变换器块,它们在相同的尺度上运行,具有相同的跨通道分辨率和时空维度 D \times T \times H \times W。在输入端(表 2 中的立方体 1),我们将补丁(或立方体,如果它们具有时间范围)投影到较小的通道维度(例如,比典型的 ViT 模型小 8 倍),但长序列(例如比典型的 ViT 模型密度高 16 倍;参见表 1)。

         在阶段转换时(例如表 2 中的 scale1 到 scale2),已处理序列的通道维度被上采样,而序列的长度被下采样。这有效地降低了底层视觉数据的时空分辨率,同时允许网络将处理后的信息吸收到更复杂的特征中。

         Channel expansion. 当从一个阶段过渡到下一个阶段时,我们通过将前一阶段中最终 MLP 层的输出增加一个与该阶段引入的分辨率变化相关的因子来扩展通道维度。具体来说,如果我们将时空分辨率下采样 4 倍,我们就会将通道维度增加 2 倍。例如,scale 3 到 scale 4 将分辨率从 2D\times \frac{T}{s_{T}} \times \frac{H}{8} \times \frac{T}{8}  更改为表 2 中的 4D\times \frac{T}{s_{T}} \times \frac{H}{16} \times \frac{T}{16}。这大致保留了跨阶段的计算复杂度,类似于 ConvNet设计原则 [87,45]。

        Query pooling. 集中注意操作不仅在 key 和 value 向量的长度方面而且在 query 的长度方面提供了灵活性,从而在输出序列方面提供了灵活性。对 query 使用 P(Q; k; p; s) 其中内核为 s \equiv (s_{T}^{Q} , s_{H}^{Q} , s_{W}^{Q} ) 会让序列缩短 s_{T}^{Q} \cdot s_{H}^{Q} \cdot s_{W}^{Q} 。因为,我们的目的是在一个阶段的开始降低分辨率,然后在整个阶段保持这个分辨率,所以只有每个阶段的第一个池注意力操作符以非退化查询步幅 s^{Q}> 1 运行,所有其他操作符都被限制为s^{Q} \equiv (1, 1, 1)

        Key-Value pooling. 与 Query pooling 不同,改变 key 和 value 张量的序列长度,不会改变输出序列长度,因此不会改变时空分辨率。然而,它们在池化注意力算子的整体计算要求中起着关键作用。

        我们解耦了 K、V 和 Q 池化的使用,Q 池化用于每个阶段的第一层,key 和 value池化用于所有其他层。由于键和值张量的序列长度需要相同才能计算注意力权重,因此用于 key 和 value 张量的池化步长需要相同。在我们的默认设置中,我们将所有池化参数 (k; p; s) 约束为相同,即 Θ K ≡ Θ V 在一个阶段内,但自适应地改变 s w.r.t.到跨阶段的规模。

        Skip connections.由于通道维度和序列长度在残差块内发生变化,我们池化跳跃连接以适应其两端之间的维度不匹配。 MHPA 通过将 query 池运算符 P(\cdot ; \Theta_{Q})  添加到剩余路径来处理这种不匹配。如图 3 所示,我们不是直接将 MHPA 的输入 X 添加到输出,而是将池化的输入 X 添加到输出,从而将分辨率与参与 query 相匹配。

        为了处理阶段变化之间的通道尺寸不匹配,我们采用了一个额外的线性层,该层对我们的 MHPA 操作的层归一化输出进行操作。请注意,这与对非标准化信号进行操作的其他(保留分辨率)跳过连接不同。

3.3. Network instantiation details

        表 3 显示了 Vision Transformers [25] 和我们的 MViTs 基本模型的具体实例。 ViT-Base [25](表 3b)最初将输入投射到形状为 1×16×16 且尺寸为 D = 768 的贴片,然后堆叠 N = 12 个transformer块。对于 8×224×224 输入,所有层的分辨率固定为 768×8×14×14。序列长度(时空分辨率类标记)为 8·14·14 1 = 1569。

        我们的 MViT-Base(表 3b)由 4 个缩放阶段组成,每个阶段都有几个通道尺寸一致的块。 MViT-B 最初将输入投影到 D = 96 的通道维度,具有形状为 3×7×7 的重叠时空立方体。对于每个附加阶段,长度为 8∗56∗56 1 = 25089 的结果序列减少了 4 倍,最终序列长度为 8 ∗ 7 ∗ 7 1 = 393 在尺度 4 。同时,通道维度在每个阶段被上采样 2 倍,在比例为 4 时增加到 768。请注意,所有池化操作以及分辨率下采样仅在数据序列上执行,而不涉及已处理的类标记嵌入。

        我们在 scale1 阶段将 MHPA 头数设置为 h = 1,并随着通道尺寸增加头数(每头 D/h 的通道数保持在 96)。

在每个阶段转换中,前一阶段的输出 MLP 维度增加 2 倍,MHPA 池化在 Q 张量上,其中 s^{Q} = (1, 2, 2) 在下一阶段的输入。

        我们在所有 MHPA 块中使用 K、V 池化,其中 Θ K ≡ Θ V 和 s^{Q} = (1, 8, 8) 在 scale1 中自适应地衰减这个步幅,采用相同的跨阶段的比例,使得 K、V 张量在所有块中具有一致的比例。

猜你喜欢

转载自blog.csdn.net/like_jmo/article/details/127908408