【论文阅读】Training data-efficient image transformers & distillation through attention

本文主要对Facebook最近提出来的DeiT模型进行阅读分析。
在这里插入图片描述


动机:DeiT解决什么问题?

现有的基于Transformer的分类模型ViT需要在海量数据上(JFT-300M,3亿张图片)进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。
在这里插入图片描述
Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。
在这里插入图片描述
上图左边是DeiT与ViT, EfficientNet的结果对比图,右边是几种DeiT模型采用的结构。


方法

DeiT如何实现前面介绍的结果呢?主要是以下两个方面:
1) 采用合适的训练策略包括optimizer, data augmentation, regularization等,这一块该文主要是在实验部分介绍;
2)采用蒸馏的方式,结合teacher model来引导基于Transformer的DeiT更好地学习(这个论文的方法部分主要是介绍的这个);

假设已经获取得到一个较好的分类模型(teacher),采用蒸馏的方式也很简单,相对于ViT主要是增加了一个distillation token,其对应的token输出值与teacher model的输出值尽可能接近,下图表示DeiT方法的示意图。
在这里插入图片描述
针对distillation的类型,主要有两种方式,soft distillation和hard distillation,本质区别是,soft是限制student和teacher模型输出的类别分布尽可能接近,hard是限制两种模型输出的类别标签尽可能接近。

——soft distillation
在这里插入图片描述
这里用的KL散度计算分布之间的相似性。

——hard distillation
在这里插入图片描述
这里需要用 a r g m a x argmax argmax函数。从后面的实验可以看出hard distillation效果会更好一些,但因为使用了 a r g m a x argmax argmax函数,teacher model模型输出的信息会丢失很多信息,为什么hard要比soft好,本文里没有展开解释。


实验效果

关于模型蒸馏(Distillation)的实验

下面是在ImageNet上的实验结果:
在这里插入图片描述
其中符号 ↑ 384 \uparrow 384 384表示采用本文作者NIPS2019的工作[2],在224x224的图像上进行预训练,在384x384图像上进行finetune.

一个有意思的现象是,使用性能相对较差的RegNetY-4/8GF为Teacher,蒸馏后DeiT-B的结果比Teacher还要高;其中相对RegNetY-4GF提升了2.7个点,最为明显。对此,本文作者如下解释:

The fact that the convnet is a better teacher is probably due to the inductive bias inherited by the transformers through distillation, as explained in Abnar et al [3].

换句话讲就是,CNN是有inductive bias的,例如局部感受野,参数共享等,这些设计比较适应于图像任务,这里将CNN作为teacher,可以通过蒸馏,使得Transformer学习得到CNN的inductive bias,从而提升Transformer对图像任务的处理能力。

以下是采用不同的distillation方式,在ImageNet中的结果,实验表明hard distillation 效果好于soft的方式,而在测试时,同时使用class和distil embedding,效果会更好。
在这里插入图片描述

扫描二维码关注公众号,回复: 12440551 查看本文章

在ImageNet上的结果

下图表示DeiT在ImageNet最终的结果,注意这里面列的ViT的效果是只用ImageNet进行训练的结果,并没有用到JFT-300M数据集,整体上效果还是不错的。
在这里插入图片描述

Training details & ablation

这部分就是在方法部分讲的第一点,本文详细阐述了采用怎样的方式能够将Transformer训练好。
在这里插入图片描述
1)optimizer
采用的是adamw,即带weight decay的adam,这里它做了与SGD的对比,但没有与adam对比。

2) data augmentation
本文指出,相对于引入先验的模型来说,如CNN,transformers一般都需要一个更加大的数据集,对此就依赖于大量的数据扩充操作,这里重点采用的数据扩充方式是Rand-Augment, Mixup, CutMix。下图是一些相关操作的示意图。
在这里插入图片描述
3) regularization
这里指出采用Random Erasing和Stochastic depth等方式有助于模型的收敛,尤其是采用较深的模型时。

Random Erasing[4]:随机选择一个区域,然后采用随机值进行覆盖。
在这里插入图片描述
Stochastic depth[5]: 随机失活一些卷积层,只保留 shortcut 通路的方式随机跳过 一些 Residual Blocks
在这里插入图片描述


总结

DeiT核心思想是采用蒸馏的方式,使得基于transformer的模型能够学习得到基于CNN模型的一些inductive bias,从而提升对图像类型数据的处理能力。蒸馏的相关操作是值得学习和借鉴的。

此外,实验部分中的Training details也非常值得借鉴,如何将transformer进行有效训练,在其他任务是也是可以利用的。


参考资料

[1] Touvron H, Cord M, Douze M, et al. Training data-efficient image transformers & distillation through attention[J]. arXiv preprint arXiv:2012.12877, 2020.

[2] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[J]. arXiv preprint arXiv:1906.06423, 2019.

[3] Samira Abnar, Mostafa Dehghani, and Willem Zuidema. Transferring inductive biases through knowledge distillation. arXiv preprint arXiv:2006.00555, 2020.

[4] Zhun Zhong, Liang Zheng, Guoliang Kang, Shaozi Li, and Yi Yang. Random erasing data augmentation. In AAAI, 2020.

[5] Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, and Kilian Q. Weinberger. Deep networks with stochastic depth. In European Conference on Computer Vision, 2016.

[6] 知乎文章,《想读懂YOLOV4,你需要先了解下列技术(一)》,《想读懂YOLOV4,你需要先了解下列技术(二)》:详细系统地总结了很多数据增强/扩充、特征增强、归一化、网络感受野增强技巧、注意力机制和特征融合技巧等方法。

猜你喜欢

转载自blog.csdn.net/yideqianfenzhiyi/article/details/113444303