【计算机视觉】ViT:Vision Transformer 讲解

有任何的书写错误、排版错误、概念错误等,希望大家包含指正。

在阅读本篇之前建议先学习:
【自然语言处理】Attention 讲解
【自然语言处理】Transformer 讲解
【自然语言处理】BERT 讲解

ViT : Vision Transformer

1. 模型概述

Transformer 已经在自然语言处理(NLP)领域中取得了显著效果,但是在计算机视觉(CV)领域的应用非常有限。在计算机视觉领域,对 Transformer 的应用主要体现在两类方法中:其一是将注意力与卷积神经网络结合;其二是用注意力层替换某些卷积层。显然,这两类方法本质上并不是 Transformer 架构,都没有改变对 CNNs 架构的依赖。

本文作者提出了 Vision Transformer(ViT)模型用于图像分类任务。ViT 模型结构的灵感是,尽可能不改变原始的 Transformer 结构;ViT 模型通过在大型数据集上有监督地预训练,在任务相关的小型数据集上微调的方式进行训练;ViT 模型得出的重要结论是,当拥有充足的数据对其进行预训练时,ViT 的效果会超过 CNNs,突破 Transformer 结构相比于 CNNs 缺少归纳偏置的限制,可以在下游任务中获得较好的迁移效果。ViT 的提出是 Transformer 应用于计算机视觉领域的里程碑,同时引爆了后续关于 Transformer 应用于其他视觉任务的研究,比如目标检测(object detection)和分割(segment),推动了 NLP 和 CV 领域的大一统。

2. 模型结构

在这里插入图片描述

图 1    ViT 模型

2.1. 输入序列的转化

对图像应用 Transformer 的遇到的第一个问题是如何将图像转换为 Transformer 的输入序列。在进行自注意力计算时,需要计算每个输入序列元素与其他元素的注意力权重,计算成本为输入序列长度的平方。如果直接以像素为单位将图像对应的二维张量展平成一维张量,那么大小为 224 × 224 224\times 224 224×224 的图像将转换为长度为 50176 50176 50176 输入序列。我们知道 BERT 的输入序列长度才不过几百上千,显然计算机无法处理长度为 50176 50176 50176 的输入序列,更不要说处理分辨率为 256 × 256 256\times 256 256×256 320 × 320 320\times 320 320×320 的图像。

已经提出的两个有代表性的处理方式为,第一种是将图像经过 CNNs 的中间特征图展开后作为 Transformer 的输入序列,第二种的思想是放弃对整张图进行注意力操作,而是在局部小窗口内进行操作,极端情况下,可以仅沿某个维度进行注意力操作。虽然第二种在理论上是有效的,但是注意力操作比较特殊,现在的硬件不支持加速这类操作,因此实际执行速度不佳。

本文作者采用将图像分割成大小一致的 patch,每个 patch 作为输入序列元素。对于大小为 224 × 224 224\times 224 224×224 的图像,如果规定 p a t c h _ s i z e = 16 \rm patch\_size=16 patch_size=16,即每个 patch 对应的向量长度为 16 × 16 = 256 16\times 16 = 256 16×16=256,那么图像将被分割为 14 × 14 14\times 14 14×14 个 patch,对应的输入序列长度为 196 196 196,这在计算机的处理范围内。如果是三通道图像,每个 patch 对应的向量长度变为 16 × 16 × 3 = 768 16\times 16\times 3 = 768 16×16×3=768,输入序列的长度不变。

2.2. 类别记号与位置编码

受 BERT 启发,作者在输入序列前引入了 <CLS> 记号,用于最后的图像分类,意义与 BERT 中的记号类似。将 patch 序列输入到线性映射层得到每个元素的 Embedding,与对应位置编码的加和作为 Transformer Encoder 的输入,这与标准 Transformer 对输入的处理相同。作者通过实验发现采用一维位置编码和采用二维位置编码带来的效果非常接近,原因可能是 ViT 的作用对象是 patch 而不是像素,对于网络而言 patch 的相对位置关系比较容易理解,因此编码方式对结果影响不大。如图 2 2 2 所示。

在这里插入图片描述

图 2    在 ImageNet 数据集上用 Few-shot 评估不同位置编码对模型效果的影响

同样是受 BERT 启发,ViT 先采用大型通用数据集对模型进行有监督预训练,再使用与具体任务相关的小型数据集对模型微调。对于图像分类任务而言,在预训练阶段,将 Transformer Encoder 关于 <CLS> 标记的输出向量输入到 MLP 中用于分类;与预训练阶段不同,微调阶段直接将 <CLS> 标记的输出向量映射至图像类别数以实现分类。这是因为作者通过实验发现,将预训练后的 ViT 迁移至另外的数据集上时,零初始化的单一线性映射层比重新初始化的 MLP 更加健壮。

另外,作者对比了用 <CLS> 标记进行分类和用全局平均池化(Global Average Pooling,GAP)在效果上的差异。ViT 中全局平均池化的具体操作流程为,将每个序列元素对应的输出向量对应位置计算平均值,最终得到的还是一个长度不变的向量,比如三个序列元素对应的输出向量分别为 [1, 2, 3][1, 3, 5][2, 4, 6],全局平均池化的结果为 [(1+1+2)/3, (2+3+4)/3, (3+5+6)/3] = [1.3333, 3, 4.6667]。实验结果表明,只要参数合适,使用 <CLS> 和使用 GAP 的效果接近。如图 3 3 3 所示。

在这里插入图片描述

图 3    采用 CLS 和全局平均池化的对比。当不要求学习率相同时,二者的最佳效果接近

2.3. 模型变体

作者设计了三个不同规模的 ViT 模型,ViT-Base、ViT-Large 和 ViT-Huge。三者在层数、Transformer 单元输入输出的向量长度、MLP 隐藏层神经元个数、多头注意力的头数以及总参数量上有所不同,具体差异如下图 4 4 4 所示。
在这里插入图片描述

图 4    不同规模的 ViT 模型细节

Layers 表示 Transformer Encoder 中编码器模块的堆叠层数;Hidden size D 与 Transformer 论文中的 d m o d e l d_{\rm model} dmodel 含义相同,表示每个编码器的输入向量维度和输出向量维度;MLP size 表示 Transformer 中每个编码器的 FFN 子模块中隐藏层神经元个数;Heads 表示 Transformer 中每个编码器的 MSA(Multi-head Self Attention)子模块中头的数量。

注意,ViT 模型只使用了 Transformer 的编码器部分(变形后),不包括 Transformer 解码器部分。

2.4. 计算流程

1 1 1 描述了模型执行流程。标准 Transformer 只接收一维序列的 token embedding,为了处理二维图像,我们将图像 $ \textbf{x}∈\mathbb R^{H\times W \times C}$ 重塑为由一维展开后的 patch 构成的序列 x p ∈ R N × ( P 2 ⋅ C ) \textbf{x}_p\in \mathbb R^{N\times (P^2·C)} xpRN×(P2C),其中 ( H , W ) (H, W) (H,W) 为原始图像的分辨率, C C C 为通道数, ( P , P ) (P,P) (P,P) 为每个 patch 的分辨率, N = H W / P 2 N = HW/ P^2 N=HW/P2 表示序列 x p \textbf x_p xp 的长度,亦 patch 的数量。规定Transformer 全部子层的输入和输出向量维度均为 D D D,因此引入一个可训练的线性投影 ( E q . 1 ) (\rm Eq. 1) (Eq.1) 将 patch 映射到 D D D 维。将这个投影的输出称为 patch embedding。与 BERT 模型的 <CLS> 标记类似,我们在 patch embedding 序列前添加一个可学习的嵌入 z 0 0 = x c l a s s \textbf{z}_0^0=\textbf{x}_{\rm class} z00=xclass,其在 Transformer Encoder z L 0 \textbf{z}_{L}^0 zL0 处的输出向量作为整个图像的表征 y y y ( E q . 4 ) (\rm Eq. 4) (Eq.4)。在预训练和微调阶段,最终的分类映射层接收 z L 0 \textbf{z}_{L}^0 zL0。分类映射层在预训练时由一个单隐层 MLP 实现,在微调时由一个线性层实现。位置编码被添加到 patch embedding 序列中以保留位置信息。我们使用可学习的一维位置编码,因为没有实验显示使用二维位置会有显著的性能提升。最终的 embedding 序列作为编码器的输入。

出于实际效果的考虑,ViT 中的 Transformer Encoder 部分与标准 Transformer 的编码器部分并非完全相同。而且基于 Transformer 思想的后续论文大多采用 Transformer Encoder 的变体,而非标准 Transformer Encoder。变体仍由多头自注意力模块(MSA)和 MLP 模块交替堆叠而成 ( E q . 2 , 3 ) (\rm Eq. 2,3) (Eq.2,3)。每个模块之前为 LayerNorm(LN)层,每个模块的输出还应进行残差链接。计算流程可大致描述为:
z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; . . . ; x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D z l ′ = M S A ( L N ( z l − 1 ) ) + z l − 1 , l = 1... L z l = M S A ( L N ( z l ′ ) ) + z l ′ , l = 1... L y = L N ( z L 0 ) \begin{align} &\textbf{z}_0 = [\textbf {x}_{\rm class}; \textbf{x}_p^1 \textbf E;\textbf{x}_p^2 \textbf E;...;\textbf{x}_p^N \textbf E] +\textbf E_{pos}, &&\textbf E \in \mathbb R^{(P^2·C)\times D}, \textbf E_{pos}\in \mathbb R^{(N+1)\times D}\tag{1}\\ &\textbf {z}'_l = {\rm MSA}({\rm LN}(\textbf z_{l-1})) + \textbf z_{l-1}, && l = 1 ... L\tag{2}\\ &\textbf {z}_l = {\rm MSA}({\rm LN}(\textbf z'_{l})) + \textbf z'_{l}, && l = 1 ... L\tag{3}\\ &\textbf y = {\rm LN}(\textbf z_L^0) \tag{4} \end{align} z0=[xclass;xp1E;xp2E;...;xpNE]+Epos,zl=MSA(LN(zl1))+zl1,zl=MSA(LN(zl))+zl,y=LN(zL0)ER(P2C)×D,EposR(N+1)×Dl=1...Ll=1...L(1)(2)(3)(4)

参考源码可以得到更具体的实现细节如图 5 5 5 所示。

在这里插入图片描述

图 5    从左到右依次为 MLP block、Encoder block、Trm Encoder 和 ViT 模块细节

其中,Multi-head Dot Product Attention 模块是标准的 Transformer 中的多头注意力模块。具体地,由输入序列 z ∈ R N × D \textbf z\in \mathbb R^{N\times D} zRN×D 得到 q \textbf q q k \textbf k k v \textbf v v,注意力权重 A i j A_{ij} Aij q i \textbf q^i qi k j \textbf k^j kj 确定,表示序列中两个元素的相似性,以 A i j A_{ij} Aij 为权重对 v \textbf v v 加和,对应公式如下:
[ q , k , v ] = z U q k v U q k v ∈ R D × 3 D h A = s o f t m a x ( qk T / D h ) A ∈ R N × N S A ( z ) = A v \begin{align} &[\textbf q, \textbf k, \textbf v] = \textbf z \textbf U_{qkv} &\textbf U_{qkv}\in \mathbb R^{D\times 3D_h} \tag{5}\\ &A = {\rm softmax}(\textbf{qk}^{\rm T}/\sqrt{D_h}) &A\in \mathbb R^{N\times N} \tag{6} \\ &{\rm SA}(\textbf z) = A\textbf v \tag{7} \end{align} [q,k,v]=zUqkvA=softmax(qkT/Dh )SA(z)=AvUqkvRD×3DhARN×N(5)(6)(7)
多头注意力(MSA)是对 SA 的扩展。并行执行 k k k 个自注意操作,并对它们拼接后的输出进行线性映射。为了在改变 k k k 时保持计算量和参数数量不变, D h D_h Dh ( E q . 5 ) \rm(Eq. 5) (Eq.5) 通常设置为 D / k D/k D/k ,对应公式如下:
M S A ( z ) = [ S A 1 ( z ) ; S A 2 ( z ) ; . . . ; S A k ( z ) ] U m s a                                U m s a ∈ R k ⋅ D h × D (8) MSA(\textbf{z}) = [{\rm SA}_1(\textbf z);{\rm SA}_2(\textbf z);...;{\rm SA}_k(\textbf z)]\textbf U_{msa}\;\;\; \;\;\;\;\;\;\;\;\;\;\;\;\textbf U_{msa}\in \mathbb R^{k·D_h\times D} \tag{8} MSA(z)=[SA1(z);SA2(z);...;SAk(z)]UmsaUmsaRkDh×D(8)
AddPostionEmbs 表示将 patch 对应像素映射向量与位置编码相加;IdentityLayer 表示恒等映射,即 I d e n t i t y ( x ) = x \rm Identity(\textbf x) = \textbf x Identity(x)=x

在 ViT 模块中,预训练阶段和微调阶段对应操作模块的细微不同在图中也有体现。

另外,微调阶段输入图像的分辨率可能会大于预训练阶段的图像分辨率,仅 Transformer 的内部结构的话,完全可以处理任意长度序列,因为 Transformer 内部的自注意力模块是对全局信息的捕捉。考虑 Transformer 的输入,包括 patch embedding 和位置编码两部分,不同的序列长度不会影响每个位置 patch embedding 的生成,因为它们经过同一层线性映射;但是每个序列元素都应该有一个独一无二的位置编码,如果微调阶段的序列长度大于预训练阶段的序列长度,那么微调时的输入序列可能存在某些元素没有对应的预训练好的位置编码。作者采用二维插值的方式处理这种特殊情况。

3. 实验概述

在预训练阶段使用了三个数据集,ImageNet-1K(1.3million)、ImageNet-21K(14million)和 JFT-18K(303million),同时删除了这三个数据集中与下游测试集相同的重复样本。其中,下游测试集包括,ImageNet(on the original validation labels),ImageNet (on the cleaned-up ReaL labels ),CIFAR-10/100,Oxford-IIIT Pets,Oxford Flowers-102,VTAB (19 tasks)。

下面实验中提到的 CNNs 是指 BiT 模型或者其变体。BiT 模型与 ViT 模型都是来自 Google 的研究团队。见参考 [2]。

在这里插入图片描述

图 6    模型在不同大小数据集上预训练后在 ImageNet 上的效果对比(左)和不优化正则化超参数时模型在不同大小数据集上预训练后通过线性 Few-shot 评估效果对比

作者对比了 ViT 模型与 CNNs 模型(这里用的是 BiT 模型作对比)在不同大小的数据集上预训练后迁移至 ImageNet 数据集上进行微调的效果。实验显示,当在小型数据集上预训练时,ViT 微调后的效果远远比不上 CNNs;当在中型数据集上预训练时,二者效果相当;当在大型数据集上预训练时,ViT 性能超越 CNNs。如图 6 6 6 左所示。

作者又对比了在不优化三个正则化超参数(weight decay、dropout 和 label smoothing)的前提下,ViT 模型与 CNNs 模型在不同大小的数据集上预训练后的效果。为了节省计算资源,作者选择不进行微调而是将模型作为特征提取器,直接对输出特征进行逻辑回归实现分类。评估效果时采用 Few-shot,即每一类只采用五张图像用于评估。实验结果表明,卷积神经网络的归纳偏置对于较小的数据集是有用的,但对于较大的数据集,直接从数据中学习完全可行的,甚至是更有效的。如图 6 6 6 右所示。

至此,作者证明了本篇论文的关键想法:当预训练数据量充足时,ViT 模型能够摆脱对归纳偏置的依赖,甚至可以得到更优质的效果。

卷积神经网络的归纳偏置(inductive bias)可以认为是卷积神经网络自带的先验知识,或者是对目标函数的启发式假设,对于模型处理不同的任务起到正向作用。

卷积神经网络的归纳偏置有两个:局部相关性(locality)和平移不变性(translation equivariance)。局部相关性是指假设图像上的相邻位置具有相关性;平移不变性是指无论先进行卷积操作还是先进行窗口滑动,结果都是一样的,通俗来说,图像中的同一个物体无论移动到什么位置,只要是输入相同,那么遇到同样卷积核的输出必然相同。

有了这两个归纳偏置后,卷积模型就具备了很多先验知识,所以只需要相对少的数据就可以学习到比较好的模型,但是 Transformer 不具有这些先验知识,因此其对视觉世界的感知全部需要从数据中自己学习。上面的实验就是在验证这个假设。

从图 6 6 6 中我们可以看出 ViT 的效果没有远远超出 CNNs。事实确实如此,作者在 JFT-300M 数据集上对模型进行预训练,实验发现甚至是较小的 ViT 模型,在所有数据集上的微调效果都优于 CNNs,但是无论是较大的 ViT 模型还是较小的 ViT 模型在效果上与 CNNs 仅差不到百分之一个点,如图 7 7 7 所示。这显然无法突显 ViT 模型的优势,因此作者在图 7 7 7 的最后一行列出了模型训练时长。对比时长发现,ViT 模型在大幅降低计算所需资源的情况下依然能保证效果不比 SOTA(State Of The Art)模型差,这使得 ViT 模型不至于没有亮点。

图 7    对比 SOTA 模型与 ViT 模型分类效果

为了进一步证明 ViT 模型的预训练过程具有资源消耗少的优势,作者在这部分实验中,除了对比 ViT 和 CNNs 外,还引入了 Hybrid(混合)模型作对比。Hybrid 模型是指以 ResNet 提取的特征图作为 ViT 模型的输入,对应于代码中, E \textbf E E ( E q . 1 ) (\rm Eq. 1) (Eq.1) 的映射对象是来自 ResNet 的特征图。

实验的主要内容是,通过 JFT-300M 数据集对三类模型(CNNs、ViT 和 Hybrid)进行预训练,在 ImageNet、ImageNet ReaL、CIFAR-10、CIFAR-100、Pets 和 Flowers 数据集上进行效果评估。具体来说,模型包括:7 个 ResNet 模型,预训练 7 个 epoch 的 R50x1、R50x2、R101x1、R152x1 和 R152x2,以及预训练 14 个 epoch 的 R152x2 和 R200x3;6 个 ViT 模型,预训练 7 个 epoch 的 ViT-B/32、ViT-B/16、ViT-L/32 和 ViT-L/16,以及预训练 14 个 epoch 的 ViT-L/16 和 ViT-H/14;5 个 Hybrid 模型,预训练 7 个 epoch 的 R50+ViT-B/32、R50+ViT-B/16、R50+ViT-L/32 和 R50+ViT-L/16,以及预训练 14 个 epoch 的 R50+ViT-L/16。另外,由于 ImageNet 数据集比较重要,所以作者在实验中专门单独利用 ImageNet 数据集进行评估,而对于其他的五个数据集则取平均值作为估计结果,两部分实验在保证同资源消耗的前提下比较评估结果,分别如图 8 8 8 右和左所示。

图 8    ViT、CNNs 和 Hybrid 模型同资源消耗下的性能对比

另外,作者还使用类似 BERT 中完形填空的方法进行实验,即预训练阶段无需图像标签,但是效果不理想。

还有一些其他实验,这里不再讲解。

REF

[1] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.

[2] Kolesnikov A, Beyer L, Zhai X, et al. Big transfer (bit): General visual representation learning[C]//Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part V 16. Springer International Publishing, 2020: 491-507.

[3] ViT论文逐段精读【论文精读】- bilibili

[4] 【自然语言处理】Transformer 讲解 - CSDN

[5] 【自然语言处理】BERT 讲解 - CSDN博客

[6] 代码:vit_jax/models.py · master · mirrors / google-research / vision_transformer · GitCode

[7] 【transformer】ViT_transformer vit - CSDN

[8] ViT(Vision Transformer)解析 - 知乎

猜你喜欢

转载自blog.csdn.net/weixin_46221946/article/details/129639904