ICML2021《Training data-efficient image transformers & distillation through attention》

在这里插入图片描述
论文链接:http://proceedings.mlr.press/v139/touvron21a/touvron21a.pdf
代码链接:https://github.com/facebookresearch/deit

1. 动机

VIT训练需要消耗大量的计算资源,且训练时间长。此外,当没有充足数据时很难泛化

2. 贡献

  • 作者证明,不包含卷积层的神经网络可以在没有外部数据的情况下,在ImageNet上获得与当前技术水平相比具有竞争力的结果。且它们是在4个gpu的单个节点上学习的,需要3天时间。本文的两个新模型DeiT-S和DeiT-Ti参数更少,可以看作是ResNet-50和ResNet-18的对等物。
  • 引入了一种基于蒸馏token的新蒸馏过程,它的作用与类token相同,只是它的目的是再现教师估计的标签。这两个token通过注意在Transformer中交互。这种Transformer专用策略比传统蒸馏的效果好得多。
  • 作者在Imagenet上预先学习的模型在转移到不同的下游任务时是有竞争力的,例如细粒度分类,在几个流行的公共基准上:CIFAR-10、CIFAR-100、Oxford-102 flowers、Stanford Cars和inaturalist18 /19。

3. 方法

3.1 视觉Transformer

  • Multi-head Self Attention layers (MSA)
    注意力机制是基于一个可训练的(key,value)向量对的联合记忆。使用内积将query向量 q ∈ R d q \in \mathbb{R}^d qRd与一组 k k k个key向量(打包成一个矩阵 K ∈ R k × d K \in \mathbb{R}^{k \times d} KRk×d)进行匹配。然后,用softmax函数对这些内积进行缩放和归一化,以获得 k k k个权重。注意力的输出是一组 k k k值向量的加权和(打包成 V ∈ R k × d V \in \mathbb{R}^{k \times d} VRk×d)。对于 N N N个query向量序列(打包成 Q ∈ R N × d Q \in \mathbb{R}^{N \times d} QRN×d),它产生一个输出矩阵(大小为 N × d N \times d N×d):
    在这里插入图片描述
    其中Softmax函数应用于输入矩阵的每一行, d \sqrt{d} d 提供了适当的标准化。Vaswani et al.(2017)提出了自我注意力层。query,key和value矩阵本身计算从 N N N个输入向量序列(塞进 X ∈ R N × D X \in \mathbb{R}^{N \times D} XRN×D): Q = X W Q Q = XW_Q Q=XWQ K = X W K K = XW_K K=XWK, V = X W V V = XW_V V=XWV,使用拥有约束 k = N k = N k=N的线性变换 W Q W_Q WQ W K W_K WK W V W_V WV,这意味着attention是在所有的输入向量之间。
    最后,通过考虑 h h h个注意头,即 h h h个自注意函数,定义了多头自注意力层(MSA)。每个头提供一个大小为 N × d N \times d N×d的序列。这些 h h h序列被重新排列成 N × d h N \times dh N×dh序列,然后由线性层重新投影成 N × D N \times D N×D

  • Transformer block for images
    为了得到一个完整的Transformer块(Vaswani等人,2017),作者在MSA层上添加了一个前馈网络(FFN)。这个FFN由两个线性层组成,被GeLu激活隔开(Hendrycks &Gimpel, 2016)。第一个线性层将维数从 D D D扩展到 4 D 4D 4D,第二层将维数从 4 D 4D 4D还原到 D D D。由于跳跃连接,MSA和FFN都作为残差算子运行,并进行了层归一化(Ba et al., 2016)。
    为了获得一个Transformer来处理图像,本文工作建立在ViT模型之上(Dosovitskiy等人,2020年)。它是一种简单而优雅的体系结构,它处理输入图像就像处理输入标记序列一样。将固定大小的输入RGB图像分解为一批 N N N个固定大小为 16 × 16 16 \times16 16×16像素( N = 14 × 14 N = 14 \times 14 N=14×14)的patch。每个patch用一个线性层进行投影,该层保持其总体维度 3 × 16 × 16 = 768 3 \times 16 \times 16 = 768 3×16×16=768。上述Transformer块是不变的patch嵌入顺序,因此忽略了他们的位置。位置信息被合并为固定的(Vaswani et al., 2017)或可训练的(Gehring et al., 2017)位置嵌入。它们被添加到patch token的第一个Transformer块之前,然后将其提供给Transformer块堆栈。

  • The class token
    类token是一个可训练的向量,被附加到第一层之前的patch token中,它经过转换层,然后与线性层一起投影来预测类。这个类token继承自NLP (Devlin et al., 2018),并与计算机视觉中用于预测类的典型池化层不同。因此,Transformer处理维 D D D ( N + 1 ) (N + 1) (N+1)token批次,其中只有类向量用于预测输出。这种体系结构迫使自注意在patch token和类token之间传播信息:在训练时,监督信号只来自类嵌入,而patch token是模型的唯一变量输入。

  • 固定跨分辨率的位置编码
    Touvron等人(2019)表明,使用较低的训练分辨率并在较大的分辨率下微调网络是可取的。在现有的数据增强方案下,这加快了完整的训练,并提高了准确性。当增加一个输入图像的分辨率时,我们保持patch的大小不变,因此输入的 N N N个patch的数量会发生变化。由于Transformer块和类token的体系结构,不需要修改模型和分类器来处理更多token。相反,需要适应位置嵌入,因为有 N N N个位置嵌入,每个patch对应一个位置嵌入。Dosovitskiy等人(2020)在改变分辨率时插入位置编码,并证明该方法适用于随后的微调阶段。

3.2 Distillation through attention

在本节中,作者假设可以使用一个强图像分类器作为教师模型。它可以是卷积神经网络,也可以是分类器的混合体。然后利用这位老师来解决如何学习Transformer的问题。如下表格通过比较精度和图像吞吐量之间的权衡所看到的,用Transformer替代卷积神经网络是有益的。这里涵盖蒸馏的两个方向:硬蒸馏与软蒸馏,经典蒸馏与蒸馏token。
在这里插入图片描述

  • Soft distillation
    一些工作最小化教师模型的softmax和学生模型的softmax之间的Kullback-Leibler分歧。设 Z t Z_t Zt为教师模型的对数, Z s Z_s Zs为学生模型的对数,用 τ \tau τ表示蒸馏温度, λ \lambda λ平衡Kullback Leibler散度损失(KL)和交叉熵(LCE)的系数:
    在这里插入图片描述

  • Hard-label distillation
    这里引入了一种蒸馏的变体,把老师的hard决定作为一个真正的标签。令 y t = a r g m a x c Z t ( c ) y_t = argmax_cZ_t(c) yt=argmaxcZt(c)是教师的hard决定,与此Hard-label蒸馏相关的目标是:
    在这里插入图片描述
    对于给定的图像,与教师相关的Hard-label可能会根据具体的数据扩充而改变。我们将看到,这个选择比传统的选择更好,同时无参数,概念更简单:教师预测 y t y_t yt与真实标签 y y y发挥相同的作用.

  • Label smoothing
    hard-label也可以通过标签平滑转换为软标签(Szegedy等人,2016),其中真正的标签被认为有1个 ε \varepsilon ε的概率,其余的 ε \varepsilon ε在其余的类中共享。在所有使用真标签的实验中,都将 ε = 0.1 \varepsilon= 0.1 ε=0.1。请注意,这里不会平滑教师提供的伪标签(例如,在hard蒸馏中)。

  • Distillation token
    如图2所示,作者向初始嵌入(patch和类token)添加一个新的token,即蒸馏token。这里的蒸馏token类似于类token的使用:它通过自我注意与其他嵌入进行交互,最后一层之后由网络输出。它的目标是由损失的蒸馏成分给出的。蒸馏嵌入允许我们的模型从教师的输出中学习,就像在常规的蒸馏中一样,同时与类嵌入保持互补。
    在这里插入图片描述

  • Fine-tuning with distillation
    作者使用真实标签和教师预测在微调阶段在更高的分辨率。使用具有相同目标分辨率的教师,通常是通过Touvron等人(2019)的方法从低分辨率的教师那里获得的。这里也只测试了真正的标签,但这降低了教师的利益,并导致较低的表现。

  • Classification with our approach: joint classifiers
    在测试时,由Transformer产生的类或蒸馏嵌入都与线性分类器相关联,并能够推断出图像标签。本文的参考方法是对这两个单独的头进行后期融合,然后将两个分类器的softmax输出相加,进行预测。如下表评估这三个选项。
    提出的策略进一步提高了性能,表明两个token提供了互补的有用的分类信息:两个token上的分类器明显优于独立类和蒸馏分类器,它们本身已经优于蒸馏基线。与蒸馏token相关联的嵌入比类token的嵌入效果稍好。它与卷积神经网络预测的相关性也更强。在所有情况下,包含它都会提高不同分类器的性能。
    在这里插入图片描述

3.3 DeiT架构三种变体

在这里插入图片描述

4. 部分实验结果

4.1 不同教师架构下的蒸馏结果

正如Abnar等人(2020)所解释的那样,卷积神经网络是一个更好的老师,这可能是因为Transformer通过蒸馏继承了归纳偏见。在本文随后的所有蒸馏实验中,默认的教师网络是一个RegNetY-16GF (Radosavovic等人,2020年),具有84M参数,使用与DeiT相同的数据和相同的数据扩充进行训练。这个教师网络在ImageNet上的准确率达到了82.9%的第一名
在这里插入图片描述

4.2 蒸馏方法的对比

如上表3中比较了不同蒸馏策略的性能。对于Transformer,硬蒸馏显著优于软蒸馏,即使只使用类token:在分辨率为 224 × 224 224 \times 224 224×224时,硬蒸馏达到83.0%,而软蒸馏精度为81.8%。
在这里插入图片描述

4.3 赞同教师的观点&归纳偏见?

如上所述,教师的架构有着重要的影响。它是否继承了现有的有利于训练的归纳偏见? 虽然我们认为很难正式地回答这个问题,但我们在表4中分析了convnet教师、本文的图像Transformer DeiT只从标签中学习,以及DeiT之间的决策协议。与从零开始学习的Transformer相比,本文的蒸馏模型与卷积神经网络更相关。正如预期的那样,与蒸馏嵌入相关联的分类器比与类嵌入相关联的分类器更接近卷积网络,而与类嵌入相关联的分类器则更类似于没有蒸馏学习的DeiT。不出所料,联合类+蒸馏分类器提供了一个中间地带。
在这里插入图片描述

4.4 token分析

学习的类和蒸馏token向不同的向量收敛:这些token之间的平均余弦相似度(cos)等于0.06。每一层计算的类嵌入和蒸馏嵌入通过网络逐渐变得更加相似,一直到相似度高的最后一层(cos=0.93),但仍然小于1。这是可以预料的,因为它们的目标是产生相似但不相同的目标。
与简单地添加与同一目标标签相关的额外类token相比,本文验证了蒸馏token向模型中添加了一些东西:作者使用一个带有两个类token的Transformer,而不是教师伪标签。即使我们随机独立地初始化它们,在训练过程中它们收敛于同一个向量(cos=0.999),输出嵌入也准相同。与我们的蒸馏策略相反,额外的类token并不会对分类性能带来任何影响。

4.5 迁移学习到下游任务

尽管DeiT在ImageNet上表现得很好,但为了衡量DeiT的泛化能力,在其他数据集上评估它们的迁移学习是很重要的。通过对表8中的数据集进行微调,作者在迁移学习任务中评估了这一点。表6将DeiT迁移学习结果与ViT和EfficientNet进行了比较。DeiT与有竞争力的convnet模型相当,这与之前在ImageNet1k上得出的结论一致。
在这里插入图片描述
在这里插入图片描述

4.6 数据增强

与集成更多先验(如卷积)的模型相比,transformer需要更多的数据。因此,为了用相同大小的数据集进行训练,本文依赖于广泛的数据扩充。作者评估不同类型的强数据增强,以达到数据高效的训练方法。同时,考虑了不同的优化器,并交叉验证了不同的学习速率和重量衰减。transformer对优化超参数的设置非常敏感。
在这里插入图片描述

4.7 训练时间

对于DeiT-B来说,300个epoch的典型训练需要37小时(2个节点)或53小时(单个8-GPU节点)。作为对比,使用RegNetY-16GF(84M参数)进行类似的训练要慢20%。DeiT-S和DeiT-Ti在4个GPU上训练不到3天。然后,我们可以在更大的分辨率上微调模型。这需要在8个gpu上花20个小时来调整分辨率为 384 × 384 384 \times 384 384×384的DeiT-B模型,这相当于25个epoch。不需要依赖于批处理规范,可以在不影响性能的情况下减少批处理大小,这使得训练更大的模型更容易。请注意,由于使用了重复增广进行3次重复,在一个时期内只能看到三分之一的图像。

5. 结论

  • 使用convnet教师比使用Transformer提供更好的性能。
  • 对于Transformer,hard蒸馏显著优于soft蒸馏,即使只使用类token
  • Transformer随着网络加深,token之间逐渐变得更加相似,即过平滑
  • 在小数据集上,在没有Imagenet预训练的情况下,因为网络的多样性要低得多,所以从头开始训练的性能没有预训练过的好
  • 本文的实验证实,Transformer需要强大的数据增强:作者们评估的几乎所有数据增强方法都被证明是有用的。一个例外是dropout,作者将其排除在训练流程之外
  • Mixup和Cutmix这样的正则化可以提高transformer性能

猜你喜欢

转载自blog.csdn.net/weixin_43994864/article/details/123589610