【轻量化深度学习】知识蒸馏与NLP语言模型的结合

Knowledge Distillation

Student : Wenxuan Zeng

School : University of Electronic Science and Technology of China

Date : 2022.3.25 - 2022.4.3



参考论文: Distilling the Knowledge in a Neural Network

这篇论文是知识蒸馏的开山之作,发表于NIPS’14,非常值得我们去学习研究。所以我先从这篇论文入手去学习知识蒸馏,然后去学习如何使用知识蒸馏去压缩BERT模型。
在这里插入图片描述

1 Knowledge的定义

如果说知识就是模型中的参数,那么将难以迁移,因为两个不同的模型并没有一一对应的参数。教师网络预测结果中各个类别概率的相对大小隐式地包含了知识,在文中也称知识是从输入向量到输出向量的映射。举个直观的例子,对于小轿车的图片,模型会给出所有物体的预测概率,比如会有一部分概率是公交车,减小一部分概率是胡萝卜,那么教师网络就能教给学生网络这样的知识——这张图片大概率是一辆小轿车,不太可能是公交车或胡萝卜,并且这张图片更像公交车,而更不像胡萝卜。实际上,就是表明知识包含正确信息,同时也包含错误信息之间的相对关系。

2 Soft targets

在这里插入图片描述

一种将笨重模型的泛化能力迁移到小模型的方式就是将笨重模型所产生的类别概率作为soft targets来训练小模型。Soft targets中包含了较高的熵,所以提供了更为详细的信息;而hard target (one-hot encoding)则熵低,提供较少的信息。

什么是Soft/Hard targets?举个例子,在三分类问题中,小轿车的hard target也许能表达成这样:(0, 0, 1),那soft target也许是这样的:(0.1, 0.3, 0.6)。显然,soft targets中包含了更多的信息,比如之前提到的“这张图片更像公交车,而更不像胡萝卜”类似的相对信息。

Hard loss:
L h a r d = − 1 N ∑ i = 1 N l o g ( P ( x i ) ) L_{hard}=-\frac{1}{N}\sum^N_{i=1} log(P(x_i)) Lhard=N1i=1Nlog(P(xi))
Soft loss:
L s o f t = − 1 N ∑ i = 1 N ∑ j = 1 N y i j l o g ( P ( x i j ) ) L_{soft}=-\frac{1}{N}\sum^N_{i=1}\sum^N_{j=1} y_{ij} log(P(x_{ij})) Lsoft=N1i=1Nj=1Nyijlog(P(xij))

3 T-Softmax

复习一下softmax的作用,在做分类任务时,通过softmax将所有类别的概率压缩到 [0,1] 的范围内,并且概率值求和为1。Softmax表达式如下:

q i = e x p ( z i ) ∑ j e x p ( z j ) q_i=\frac{exp(z_i)}{\sum_j exp(z_j)} qi=jexp(zj)exp(zi)
T-Softmax就是在softmax的基础上,让每一个输入的z去除以T,如下所示:

q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_j exp(z_j/T)} qi=jexp(zj/T)exp(zi/T)
这里的T就是蒸馏温度,当T=1时,就是softmax,当T>1时,得到soft targets。

结论:T越大,得到的预测结果越soft,各个类别的概率值越接近,所以其中包含的知识会更多。

在这里插入图片描述

4 知识蒸馏

4.1 蒸馏流程

下图是知识蒸馏的过程,教师网络在温度为t的时候训练,得到soft labels,学生网络是温度为t的时候训练,得到soft predictions,通过拟合soft labels和soft predictions,引导学生网络学习教师网络学到的知识**(比喻soft labels是老师的言传身教)。另外,学生网络在温度为1的时候训练,得到hard prediction,也就是one-hot encoding,然后用交叉熵损失函数与hard label计算出student loss(比喻hard label是课本知识)**。

在这里插入图片描述

在这里插入图片描述

4.2 Loss function

L = γ L h a r d + ( 1 − γ ) T 2 L s o f t L = \gamma L_{hard} + (1-\gamma)T^2 L_{soft} L=γLhard+(1γ)T2Lsoft

注意,在soft loss处需要乘上 T 2 T^2 T2,改变用于蒸馏的温度,硬目标和软目标的相对贡献大致保持不变。

在这里插入图片描述

4.3 预测值匹配是一种特殊形式的知识蒸馏

Model Compression (SIGKDD’06) 这篇论文中,作者通过知识迁移实现了模型的压缩,详细来说就是将教师网络和学生网络的logits求得MSE。而在本文中,作者说这种压缩方式是蒸馏的一个特例。

∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T} (\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i}/T}{\sum_j e^{v_j}/T}) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)
其中, q i q_i qi是学生网络预测的后验概率, p i p_i pi是教师网络预测的后验概率。

假设蒸馏温度T足够高,那么根据泰勒展开: e x = 1 + x e^x=1+x ex=1+x

∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( 1 + z i / T N + ∑ j z j / T − 1 + v i / T N + ∑ j v j / T ) \frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{1+{z_i}/T}{N+\sum_j {z_j}/T}-\frac{1+{v_i}/T}{N+\sum_j {v_j}/T}) ziC=T1(qipi)=T1(N+jzj/T1+zi/TN+jvj/T1+vi/T)
假设对于不同样本的logits期望为0,则

∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i-v_i) ziCNT21(zivi)
综上,如果蒸馏温度足够高,并且logits的期望为0,那么知识蒸馏等价于最小化 M S E = 1 / 2 ( z i − v i ) 2 MSE=1/2(z_i-v_i)^2 MSE=1/2(zivi)2​。

但是在实际情况下,并不能做到温度无穷大。 下图可以看出,温度太小的时候,很小的logits对应的softmax值被压到0,没有话语权,无法发挥蒸馏的效果;而温度太大的时候,所有类别的概率趋同,可能带来噪声。
在这里插入图片描述

温度T用多大比较好,这个需要靠经验决定,一般来说中间温度效果最佳。

在这里插入图片描述

4.4 知识蒸馏简单计算

在这里插入图片描述

5 实验设计

很奇妙的是,学生网络可以做到零样本学习,比如学生网络没有见过CNN中的平移不变性知识,但是仍然可以通过教师网络的知识迁移去学到。把数字3从学生网络的训练中抹掉,学生网络仍然可以从教师网络的知识中学到3的特征(作者手动调大了bias)。

6 知识蒸馏发展方向

  • 安排助教,安排多个老师、多个学生
  • 知识表示的表示(中间层)
  • 多模态、知识图谱、预训练大规模模型的蒸馏

7 知识蒸馏在NLP领域的研究

在这部分,我选出了几篇非常经典的BERT蒸馏的论文,然后对BERT蒸馏的思想进行学习,下面是我的一些学习记录。

7.1 Distilled BiLSTM

链接: Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

方法: 教师模型采用fine-tune的BERT-LARGE模型,学生模型采用BiLSTM+ReLU,蒸馏目标是学生模型与hard labels的交叉熵+与BERT-LARGE的logits之间的MSE。
在这里插入图片描述

7.2 BERT-PKD

链接: Patient knowledge distillation for bert model compression (ACL’19)

方法: 不直接从模型的最后一层进行蒸馏,而是从教师模型的中间层提取知识进行蒸馏。本文提出了两种不同的蒸馏方式:Skip-k层的蒸馏方式和最后k层的蒸馏方式。

在这里插入图片描述

通过交叉熵损失函数定义学生和教师模型之间的预测值差距:

在这里插入图片描述

除了让学生模仿教师,还定义了任务相关的交叉熵损失函数:

在这里插入图片描述

另外,还定义了标准化后的隐藏状态的MSE loss作为损失函数:

在这里插入图片描述

7.2 DistillBERT

链接: Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter (NIPS’19)

方法:预训练阶段采用知识蒸馏技术压缩BERT,为了利用预训练时从大模型中学到的归纳偏差,引入了结合了语言建模、蒸馏和余弦距离损失的三元loss。

在这里插入图片描述

7.3 TinyBERT

链接:Tinybert: Distilling bert for natural language understanding (ACL’20)

方法:提出了two-stage learning framework,分别在预训练和fine-tune阶段蒸馏教师模型,得到了参数量减少7.5倍,速度提升9.4倍的4层BERT,效果可以达到教师模型的96.8%,同时这种方法训出的6层模型甚至接近BERT-base,超过了BERT-PKD和DistillBERT。本文提出注意力矩阵的蒸馏,用MSE作为损失函数拟合教师和学生的注意力矩阵。

在这里插入图片描述

同时对embedding layer和hidden layer都做知识蒸馏,同样采用MSE作为损失函数:

在这里插入图片描述

最后,用交叉熵损失函数去衡量教师和学生模型的logits差距:
在这里插入图片描述

综合上面提到的蒸馏目标,根据蒸馏的layer,决定采用哪个蒸馏的loss:

在这里插入图片描述

7.4 MobileBERT

链接: MobileBERT:a Compact Task-Agnostic BERT for Resource-Limited Devices (ACL’20)

方法: 采用了瓶颈结构和自注意力与前馈神经网络的平衡机制,将知识从教师模型蒸馏到学生模型,使模型具有更窄的宽度。(具体笔记在前面文档中的Paper Understanding部分有写)

在这里插入图片描述

7.5 MiniLM

链接: MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers (NIPS’20)

方法: 虽然之前的文章把模型蒸馏了个遍,从embeddin layer到hidden layer,又到attention layer,最后到prediction layer,但是本文仍然找了一个新的点去蒸馏,并取得了非常好的效果。这篇文章蒸馏self-attention模块,提出value之间的scaled dot-product (value-relation) 作为新的深度自注意力知识。另外,本文用了一个teacher assistant去辅助大模型的蒸馏。

在这里插入图片描述

自注意力矩阵之间的关系用KL散度来衡量:

在这里插入图片描述

下面是本文定义的value-relation,实际上就是对value做scaled dot product,然后用KL散度衡量两个VR矩阵:

在这里插入图片描述

结论: 这篇文章说“只蒸馏最后一层效果比layer-to-layer要好,而且不用严格去对应两个模型的每一层,只蒸馏最后一层也能提高学生的性能,并使学生具有更强的泛化能力”。

8 我对知识蒸馏的思考 ⭐

  • 通过前面的论文可以发现,一些论文认为耐心地蒸馏中间层会带来很好的效果,同时一些论文坚持只蒸馏最后一层会带来很好的效果。他们都在自己的论文中自圆其说,所以我认为,如何去自适应地选择特定layer蒸馏,是一个值得思考的方向
  • 最近我也在学习AutoML中的AutoLoss相关的论文,一些idea是通过教师网络(辅助任务的完成)指导学生网络(完成实际任务的模型)学习一个最佳的loss function。实现这个idea的方法会有很多,我觉得这种“教师指导学生以更好地完成目标任务”这种思想,是有前景的,或者说,是值得我们进一步去思考的。
  • 在我读到的论文中,大多都是“从教师到学生的知识迁移”,我在考虑能不能从“教学相长”这个角度去改进,也就是说不仅是老师交给学生知识,学生能否反馈给老师东西,教会老师一些道理呢?这种方式似乎又可以描述为“去掉教师模型,用两个或多个模型去相互学习、相互促进。”

前面的总结是之前写的,这里是来自两周后的补充:
后来我读到挺多论文在做teacher-student co-teaching过程,包括引入curriculum learning(课程学习)、generate pseudeo labels(生成伪标签),还有的paper用到self-distillation、孪生网络相互学习、多阶段蒸馏等等技术。
对于两个模型相互学习、相互蒸馏的过程,有论文专门提出了Mutual-Distillation及相关方法;对于自适应选取layer进行蒸馏的问题,也有论文专门做了Attention-based的distillation layer选取方案,将特定任务中具体layer的语义考虑进去。

猜你喜欢

转载自blog.csdn.net/qq_16763983/article/details/124430975