CVPR 2022|解耦知识蒸馏!旷视提出DKD:让Hinton在7年前提出的方法重回SOTA行列!...

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

转载自:机器之心  |  旷视科技等

与主流的feature蒸馏方法不同,本研究将重心放回到logits蒸馏上,提出了一种新的方法「解耦知识蒸馏」,重新达到了SOTA结果,为保证复现该研究还提供了开源的蒸馏代码库:MDistiller。

1 研究摘要

近年来顶会的 SOTA 蒸馏方法多基于 CNN 的中间层特征,而基于输出 logits 的方法被严重忽视了。饮水思源,本文中来自旷视科技 (Megvii)、早稻田大学、清华大学的研究者将研究重心放回到 logits 蒸馏上,对 7 年前 Hinton 提出的知识蒸馏方法(Knowledge Distillation,下文简称 KD)[1] 进行了解耦和分析,发现了一些限制 KD 性能的重要因素,进而提出了一种新的方法「解耦知识蒸馏」(Decoupled Knowledge Distillation,下文简称 DKD)[2],使得 logits 蒸馏重回 SOTA 行列。

同时,为了保证复现和支持进一步研究,该研究提供了一个全新的开源代码库 MDistiller,该库涵盖了 DKD 和大部分的 SOTA 方法,并不断进行更新维护,欢迎大家试用并提供宝贵的反馈意见。

6e6791f1ed62903bff17a33dd782772a.png

  • 论文链接:https://arxiv.org/abs/2203.08679

  • 代码链接:https://github.com/megvii-research/mdistiller

2 研究动机

e8b0ada2f34c11e1c9fe38dc2e1dd482.png

上图是大家熟知的 KD 方法,KD 用 Teacher 网络和 Student 网络的输出 logits 来计算 KL Loss,从而实现 dark knowledge 的传递,利用 Teacher 已经学到的知识帮助 Student 收敛得更好。在 KD 之后,更多的基于中间特征的蒸馏方法不断涌现,不断刷新知识蒸馏的 SOTA。但该研究认为,KD 这样的 logits 蒸馏方法具备两点好处:

1. 基于 feature 的蒸馏方法需要更多复杂的结构来拉齐特征的尺度和网络的表示能力,而 logits 蒸馏方法更简单高效;

2. 相比中间 feature,logits 的语义信息是更 high-level 且更明确的,基于 logits 信号的蒸馏方法也应该具备更高的性能上限,因此,对 logits 蒸馏进行更多的探索是有意义的。

该研究尝试一种拆解的方法来更深入地分析 KD:将 logits 分成两个部分(如图),蓝色部分代表目标类别(target class)的 score,绿色部分代表非目标类别(Non-target class)的 score。这样的拆解使得我们可以重新推导 KD 的 Loss 公式,得到一个新的等价表达式,进而做更多的实验和分析。

2.1 符号定义

这里只写出关键符号定义,更具体的定义请参考论文正文。

首先,该研究将第 i 类的分类概率表示为(其中bf062b8e5a94d39539fe7e0e1f6f3268.png表示网络输出的 logits):

8c986de26a2919bef11d0ac5e3fddd2c.png

为了拆解分类网络输出的 logits,该研究接下来定义了两种新的概率分布c469fab11b550bfc3317a445139cd91d.png

1. 目标类 vs 非目标类的二分类分布,该概率分布和分类监督信号高度耦合。该分布包含两个元素:目标类概率和全部非目标类概率,分别表示为:

136973a34616982ab333d442eae68620.png

2. 非目标类内部竞争的多分类分布eaf2f7304a14162c799c36337fc037a3.png,也就是在预测样本为非目标类的前提下每个类各自的概率(总和为 1)。这个概率分布和分类的监督信号是不相关的,换句话说,从这个概率分布中无法得知目标类上的预测置信度,其表达式为:

ac41f74bc1dfe680ea0108aafe0f1e62.png

根据上述定义,可以得到一个显然的数学关系:ab014182786f6a13daa80ccd271a4aa9.png。这些定义和数学关系将帮助我们得到 KD Loss 的一个新的表达形式。

2.2 重新推导 KD Loss

首先,KD 的 Loss 定义如下:

18920df3aa448a8bc5bc57792c6715d6.png

然后根据公式(1)和(2),我们可以将其改写为:

c7df3e2b72f8209db5ed05062c3b5093.png

可以观察到,式中的第一项dfa6905b505134bab262f1dfeb49e073.png只牵涉到了目标类别 vs 非目标类别的二分类概率分布3292a5edfe0fd76610c16b1fea940e66.png,第二项68c02d123617e03303b89a1491c19afc.png牵涉到了非目标类概率分布的 KL 散度918b43614cecbea0c0ded2c274893328.png和权重379d14928de884cd9c141fda938f40c5.png。该研究将第一项命名为目标类别知识蒸馏 Target Class Knowledge Distillation(下文简称 TCKD),将第二项中的 KL 散度命名为非目标类别知识蒸馏 Non-target Class Knowledge Distillation(下文简称 NCKD)。至此,该研究完成了对 KD Loss 的拆分,将其分成了两个可单独使用的部分,并可以分析其各自的作用:

c4058d4ae43b2ee5f8676904ee661b49.png

3 启发式探索

首先,该研究对 TCKD 和 NCKD 做了消融实验,观察它们对蒸馏性能的影响;接着,他们分别探索 TCKD 和 NCKD 的作用;最后,研究者做了一些启发式的讨论。

3.1 单独使用 TCKD/NCKD 训练

436c47efe71bfffa890cd1b2059915be.png

如表 1 所示,我们可以观察到:

1. 同时使用 TCKD 和 NCKD(等同于 KD),有不错的性能提升;

2. 单独使用 TCKD 进行蒸馏,会对蒸馏效果产生较大的损害(这一点在补充材料中有详细讨论,主要和蒸馏温度 T 相关);

3. 单独使用 NCKD 进行蒸馏,和 KD 的效果是差不多的,甚至有时会更好;

基于这些观察可以推出两个初步结论:

1.TCKD 是没用的,甚至在单独使用时可能是有害的;

2.NCKD 可能是 KD 生效的主要原因;

接下来该研究就这两个初步的结论进行了进一步的分析。

3.2 TCKD:传递样本难度相关的知识

TCKD 作用于目标类的二分类概率分布上,这个概率的物理含义是「网络对样本的置信度」。比如:如果一个样本被 Teacher 学会了,会产生类似[0.99, 0.01] 的 binary 概率,而如果一个样本比较难拟合,则会产生类似 [0.6, 0.4] 的 binary 概率。所以该研究猜测:TCKD 传递了和样本拟合难度相关的知识,当训练集拟合难度高时才会起到作用。为了证明这一点,该研究设计了三组实验来增加 CIFAR-100 的训练难度,观察 TCKD 是否有效:

更强的数据增广:

65ce48afbf42f0497e2c0bf4638a217c.png

以表 2 中的 ShuffleNet-V1 为例,在使用 AutoAugment 的情况下,训练集难度有了明显提升,此时仅仅使用 NCKD 只能达到 73.8% 的 student 准确率,而同时使用 TCKD 和 NCKD 可以将 student 准确率提升至 75.3%。

更 Noisy 的标签:

33acab3025aa7faa03f91e94e169b834.png

表 3 中,该研究通过控制 noisy ratio 对数据集的标签引入不同程度噪声,ratio 越大表示噪声越大。可以看到,随着数据集的噪声变大,单独使用 NCKD 的效果变得越来越差,同时引入 TCKD 的增益也越来越大。说明在越难学的数据上,TCKD 的作用就会越明显。

更难的数据集:

16e6b54e12713709c0685b3e01e6796e.png

ImageNet 是一个比 CIFAR-100 更困难的数据集,所以该研究在 ImageNet 上也进行了尝试。从表 4 可以看出,在 ImageNet 上只使用 NCKD 的效果也是没有同时使用 TCKD 和 NCKD 要好的。

总结

三组实验都反映出,当训练数据拟合难度变高时(无论是数据本身难度、还是噪声和增广带来的难度),TCKD 能提供更有效的知识,对蒸馏性能的提升也越高,这些实验在一定程度上说明了 TCKD 确实是在传递有关样本拟合难度的知识,印证了该研究的想法。

3.3 NCKD:被抑制的重要成分

表 1 中反映出的另一个有趣的现象是:只使用 NCKD 也能取得令人满意的蒸馏效果,甚至可能比 KD 更好。这样的现象反映出:非目标类别上的 logits 中蕴含的信息,才是最主要的 dark knowledge 成分。

然而当回顾 KD 的新表达式时,发现 NCKD 对应的 loss 是和权重71299112691a306ffd5d67768540bf0d.png耦合在一起的。换言之,如果 teacher 网络的预测越置信,NCKD 的 loss 权重就更低,其作用就会越小。而该研究认为,teacher 更置信的样本能够提供更有益的 dark knowledge,和 NCKD 耦合的3b3dc2d3f4631f5829e8513814efd0c3.png权重会严重抑制高置信度样本的知识迁移,使得知识蒸馏的效率大幅降低。为了证明这一点,该研究做了如下实验:

1. 依据 teacher 模型的置信度,该研究对训练集上的样本做了排序,并将排序后的样本分成置信(置信度 top-50%)和非置信 (剩余) 两个批次;

2. 训练时,对全部样本使用分类 Loss,并只对置信批次 / 非置信批次使用 NCKD Loss;

921d3da77ac11e7bec6c78c6c92dbf9e.png

实验结果如表 5 所示,0-50% 表示置信批次,50-100% 表示非置信批次。第一行是在整个训练集上做 NCKD 的结果,第二行表示只对置信批次做 NCKD,第三行表示只对非置信批次做 NCKD。显然,置信批次上使用 NCKD 带来了更主要的涨点,说明置信度更高的样本对蒸馏的训练过程是更有益的,因此是不应该被抑制的。

3.4 启发

至此,该研究完成了对 KD Loss 的解耦,并且分析了两个部分各自的作用。所有结果都表明,TCKD 和 NCKD 都有自己的重要作用,然而,研究注意到了在原始的 KD Loss 中,TCKD 和 NCKD 是存在不合理的耦合的:

1. 一方面,NCKD 和678c954f643f9f8e3df59a7fad19cdf3.png耦合,会导致高置信度样本的蒸馏效果大打折扣;

2. 另一方面,TCKD 和 NCKD 是耦合的。然而这两个部分传递的知识是不同的,这样的耦合导致了他们各自的重要性没有办法灵活调整。

4 Decoupled Knowledge Distillation

80bdd8d6c6be76637cf3117ff51023cd.png

根据推导和启发式探索,该研究提出了一种新的 logits 蒸馏方法“解耦知识蒸馏(DKD)”,来解决上一章提出的两个问题,如上图所示。DKD 的 Loss 表达式如下:

c33e1e168894a5e46182c8b115272231.png

和 KD Loss 相比,该研究将限制 NCKD 的权重8da378970adc195c058471528633a8ff.png替换为了d4189096bd2ddf595944c96215506b85.png,并给 TCKD 设置了一个权重95106fa802b21f655f229de1db36b23b.png。DKD 可以很好地解决刚才提到的两个问题:一方面,TCKD 和 NCKD 被解耦,它们各自的重要性可以独立调节;另一方面,对于蒸馏更重要的 NCKD 也不会再被 Teacher 产生的高置信度抑制,大大提高了蒸馏的灵活性和有效性。DKD 的伪代码如下:

f2b6287e782e09cf418374040bc2220e.png

5 实验结果

5.1 Decoupling 带来的好处

9b0f9b3d6c1ceb711f00dfca96ef5a51.png

首先该研究通过 ablation study 验证了 DKD 的有效性,上面的表格表明:

1. 解耦a397d76843f6aa905bd7798db5f4a1ea.png和 NCKD,也就是把5cf651c4a34c7ee2ade14fd7c3af8ab1.png设置为 1.0,可以将 top-1 accuracy 从 73.6% 提升至 74.8%;

2. 解耦 NCKD 和 TCKD 的权重,即进一步调节71e78f96532b4dd3218c4802510f2295.png的数值,可以将 top-1 accuracy 从 74.8% 进一步提升至 76.3%;

这些实验结果说明 DKD 的解耦确实能带来显著的性能增益,这一方面证明了 KD 确实存在刚才提到的两个问题,另一方面也证明了 DKD 的有效性。此外,这个表格也证明了44ccb229ceab4a9c5cce337f129fa24b.png对超参数是不敏感的,把cfe2193731caa152c4b4f4393cf1ca1c.png设置为 1.0 就可以取得令人满意的效果,所以在实际应用中只需要调节1cfdd4441c7891dea67b4b6b992cd8f7.png即可。同时,cafc3e942876b098c168d45cb9d84f73.png也不是一个敏感的超参数,在 4.0-10.0 的范围内,都可以取得令人满意的蒸馏效果。

5.2 图像分类

cd7970b9aabbae18c653794d5d00ae25.png

1d459d309edf6118b0bd8ce5500e8111.png

表 6~9 中提供了 DKD 在 CIFAR-100 和 ImageNet-1K 两个分类数据集上的蒸馏效果。和 KD 相比,DKD 在所有数据集和网络结构上都有明显的性能提升。此外,与过去最好的特征蒸馏方法(ReviewKD)相比,DKD 也取得了接近甚至更好的结果。DKD 成功使得 logits 蒸馏方法重新回到了 SOTA 的阵营中。

5.3 目标检测

e058d4c290533c29315cd0ba56055320.png

该研究也在目标检测任务(MS-COCO)上验证了 DKD 的性能。如表 10 所示,在 Detector 蒸馏中,DKD 的结果虽不如特征蒸馏的 SOTA 性能,但是依然稳定地超过了 KD 的性能。而将 DKD 和特征蒸馏方法组合起来,也可以进一步提高 SOTA 结果。

关于这一点:过去的一些工作证明了,Detection 任务非常依赖特征的定位能力,这在 Detector 蒸馏中也是成立的(如 [5] 中提到了,feature mimicking 是非常重要的)。而 logits 并不能提供 location 相关的信息,无法对 Student 的定位能力产生帮助,因此在 Detection 任务中,特征蒸馏相比 logits 蒸馏存在机制上的优势,这也是 DKD 无法超过特征蒸馏 SOTA 的原因。

6 扩展性实验和可视化

6.1 训练效率

01e1a04e6ff99c7f379c6bb5ea1ed92c.png

logits 蒸馏的好处之一是训练效率高。为了证明这一点,该研究可视化了 SOTA 蒸馏方法的训练开销。图 2 的 X 轴是每个 batch 的训练时间,Y 轴是 student 的 top-1 accuracy。显然,logits 蒸馏(KD 和 DKD)所需的训练时间是最少的,并且 DKD 用了最少的时间获得了最好的蒸馏效果。图 2 中的表格也提供了训练时间和训练所需的额外参数量,和 KD 一样,DKD 也并没有额外引入参数量,同时训练时间也几乎没有增加。logits 蒸馏的优越性显而易见。

6.2 提升大 Teacher 模型蒸馏效果

484dace51332bd377e9968363d25feb1.png

过去的一些蒸馏工作发现了一个有趣的现象:大模型并不一定是好的 Teacher 网络。对于该现象,研究者提供了一个可能的解释:大模型的 model capacity 很大,这会导致大模型产生更高的0afdc29fb6a023fc68b63395131d4ecd.png,进而导致的 NCKD 被抑制得更严重。过去的一些工作也可以基于这一点解释,如 ESKD [4] 引入了 early-stopped teacher 来缓解这一问题,这可能是因为 early-stopped 模型还没有充分拟合训练集,b311fe65027b8999345b89f3068352c5.png还比较小,所以对 NCKD 的抑制不是很严重。

为了证明该观点,研究者也进行了一系列的对比实验。如表 11 和表 12 所示,当使用 DKD 时,大模型蒸馏效果变差的问题被显著改善。该研究希望这一点可以为后续的工作提供一些 insight。

6.3 特征迁移性

c4514026754d9c6ec106d3ed9ca1932e.png

这里该研究尝试将 DKD 训练得到的 student 网络进行特征迁移。如表 13 所示,研究者将在 CIFAR-100 上训练的 student 迁移到了 STL-10 和 TinyImageNet 两个数据集上,在众多的蒸馏方法中,DKD 取得了最好的迁移效果。

6.4 可视化

这里研究者提供了两种可视化。图 3 中,与 KD 相比,DKD 的样本聚得更加紧凑,说明 DKD 帮助 student 网络学到了更加可区分的特征。图 4 中,研究者计算了 teacher 网络和 student 网络输出 logits 的相似度,和 KD 相比,DKD 训练后的 student 产生的 logits 会更像 teacher 产生的 logits,说明 teacher 的知识被更好地利用了。

311044183ce454f60c943ac88906cb14.png

7 改进方向

6fa9c13a5c0d79b9caa18e9f10f86b61.png的自适应调整:DKD 目前还需要手工调整98ea2eab1d9eb720fbc5b6cb16e01923.png的值才能达到最佳的蒸馏效果,该研究希望可以通过一些训练过程中的统计量实现对0e311fa0ff02303417252480148b5218.png的自适应调节(对于这一点,该研究已经有了初步的探索,详情可见补充材料)。

8 开源代码库 MDistiller

c15b614d89d2059865bb53e6917cfd64.png

为了保证复现和进一步的探索,该研究还开源了一个知识蒸馏的 codebase MDistiller。该 codebase 涵盖了大部分的 SOTA 方法,同时支持两种蒸馏关注的主要任务,图像分类和目标检测。该研究希望 MDistiller 可以为后续的研究者们提供一套可靠的 baseline,因此会提供长期支持。

参考文献

[1] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. In arXiv:1503.02531, 2015.

[2] Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, and Jiajun Liang. Decoupled knowledge distillation. In CVPR, 2022. 

[3] Pengguang Chen, Shu Liu, Hengshuang Zhao, and Jiaya Jia. Distilling knowledge via knowledge review. In CVPR, 2021. 

[4] Jang Hyun Cho and Bharath Hariharan. On the efficacy of knowledge distillation. In ICCV, 2019. 

[5] Tao Wang, Li Yuan, Xiaopeng Zhang, and Jiashi Feng. Distilling object detectors with fine-grained feature imitation. In CVPR, 2019. 

 
  
 
  
 
  

ICCV和CVPR 2021论文和代码下载

后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集

后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集

后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF

目标检测和Transformer交流群成立
扫描下方二维码,或者添加微信:CVer6666,即可添加CVer小助手微信,便可申请加入CVer-Transformer或者目标检测 微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer等。
一定要备注:研究方向+地点+学校/公司+昵称(如Transformer或者目标检测+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲扫码或加微信: CVer6666,进交流群
CVer学术交流群(知识星球)来了!想要了解最新最快最好的CV/DL/ML论文速递、优质开源项目、学习教程和实战训练等资料,欢迎扫描下方二维码,加入CVer学术交流群,已汇集数千人!

▲扫码进群
▲点击上方卡片,关注CVer公众号

整理不易,请点赞和在看

猜你喜欢

转载自blog.csdn.net/amusi1994/article/details/124138724