知识蒸馏(Distillation)相关论文阅读(1)——Distilling the Knowledge in a Neural Network(以及代码复现)

———————————————————————————————

《Distilling the Knowledge in a Neural Network》

Geoffrey Hintion


以往为了提高模型表现所采取的方法是对同一个数据集训练出多个模型,再对预测结果进行平均;但通常这样计算量过大。

引出一种模型压缩技术:Distillation;以及介绍了一种由一个或多个完整模型(full models)以及针对/细节/特殊模型(specialist models)的组合,来学习区分仅仅是完整模型会混淆的细节。比起expects的混合,这些specialist models可以并行训练,并且训练起来更快速。


Intorduce:

引喻:昆虫在幼虫时擅于从环境中汲取能量,但是成长为成虫后擅于其他方面,比如迁徙和繁殖等。

我们常常用相似的网络来训练不同需求的问题:对于像语音和对象识别这样的任务,训练必须从非常大的、高度冗余的数据集中提取结构,但是它不需要实时操作,并且可以使用大量的计算。

对网络新的理解:

如果能够简单的从数据中提取结构,我们应该乐意于训练非常复杂的网络。复杂网络可以理解为是由一些单个的模型或者是一个由强约束条件(例如dropout)训练得到的大型模型。一旦训练得到复杂网络,我们可以用不同的训练(即‘distillation’)来将知识从复杂网络转化成更使用于应用拓展的小模型。

知识蒸馏的难点:

如何改变网络结构但是同时保留同样的知识。抛开知识的实例化,知识可以看做为一个从输入向量到输出向量的只是地图。

对于复杂模型在大量类中区分的正常训练的目标,就是在于最大化正确答案的平均对数概率。但是缺点在于学习过程中会为所有的错误答案分配了概率,尽管这个概率很小。错误答案的相对概率反映了一个复杂网络是如何变得一般(泛化能力)的。

转化的可能性:

训练是的目标函数需要最大限度的反映正确目标,尽管如此模型还是通常被训练得在测试数据上表现的最优,但真实目标其实是需要在新的数据集上表现良好。当我们从大模型到小模型做知识蒸馏的时候,我们可以像训练大模型一下训练好小模型。一个复杂的模型具有良好的泛化能力是因为它通常是许多不同模型的平均,这样通过蒸馏的方式训练出的小模型会比传统训练出来的小模型在测试集上表现更好。

利用复杂模型的泛化能力转化为小模型的一种显而易见的方法是:

将复杂模型生成的分类概率作为训练小模型的“软目标”。在这个转移阶段,我们可以使用相同的训练集或其他的“转化”训练集。当复杂的模型是由一组更简单的模型组成时,我们可以使用它们各自的预测分布的算术或几何平均值作为软目标。当软目标具有较高的熵值时,相对“硬目标,每个训练用例所提供的信息要比硬性指标多得多,它每次训练可以提供更多的信息和更小的梯度方差。因此小模型可以比原始的复杂模型更容易地训练,而且使用的学习速率要高得多。

MNIST实例:

像MNIST这种任务,复杂模型可以给出很完美的结果,大部分信息分布在小概率的软目标中。比如一张2的图片被认为是3的概率为0.000001,被认为是7的概率是0.000000001。Caruana用logits(softmax层的输入)而不是softmax层的输出作为“软目标”。目标是使得复杂模型和小模型分别得到的logits的平方差最小。“蒸馏法”:第一步,提升softmax表达式中的调节参数T,使得复杂模型产生一个合适的“软目标”  第二步,采用同样的T来训练小模型,使得它产生相匹配的“软目标。

并且发现,比起使用未标注的数据集原始的训练集更好,尤其是在目标函数中加了一项的时候,能够让小模型预测正确的同时尽量匹配软目标。但小模型是不能完全匹配软目标的,正确结果的错误方向反而是有帮助。


Distillation:

修改后的softmax公式为:

T就是一个调节参数,通常为1;T的数值越大则所有类的分布越‘软’(平缓)。

一个简单的知识蒸馏的形式是:用复杂模型得到的“软目标”为目标(在softmax中T较大),用“转化”训练集训练小模型。训练小模型时T不变仍然较大,训练完之后T改为1。 

当正确的标签是所有的或部分的传输集时,这个方法可以通过训练被蒸馏的模型产生正确的标签。一种方法是使用正确的标签来修改软目标,但是我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是带有软目标的交叉熵,这种交叉熵是在蒸馏模型的softmax中使用相同的T计算的,用于从繁琐的模型中生成软目标。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的逻辑,但在T=1下计算。我们发现,在第二个目标函数中,使用一个较低权重的条件,得到了最好的结果。由于软目标尺度所产生的梯度的大小为1/T^2,所以在使用硬的和软的目标时将它们乘以T^2是很重要的。这确保了在使用T时,硬和软目标的相对贡献基本保持不变。


———————————————————————————————


笔者个人理解以及Pytorch代码实现:

可能读者看到此处发现并不特别清楚论文中具体的蒸馏步骤以及T参数的意义,以下对几个关键点的理解做出个人的解释,欢迎指导和讨论:

      1. T参数是什么?有什么作用?

        T参数为了对应蒸馏的概念,在论文中叫的是Temperature,也就是蒸馏的温度。T越高对应的分布概率越平缓,为什么要使得分布概率变平缓?举一个例子,假设你是每次都是进行负重登山,虽然过程很辛苦,但是当有一天你取下负重,正常的登山的时候,你就会变得非常轻松,可以比别人登得高登得远。

        同样的,在这篇文章里面的T就是这个负重包,我们知道对于一个复杂网络来说往往能够得到很好的分类效果,错误的概率比正确的概率会小很多很多,但是对于一个小网络来说它是无法学成这个效果的。我们为了去帮助小网络进行学习,就在小网络的softmax加一个T参数,加上这个T参数以后错误分类再经过softmax以后输出会变大(softmax中指数函数的单增特性,这里不做具体解释),同样的正确分类会变小。这就人为的加大了训练的难度,一旦将T重新设置为1,分类结果会非常的接近于大网络的分类效果

     2. soft target(“软目标”)是什么?

        soft就是对应的带有T的目标,是要尽量的接近于大网络加入T后的分布概率。

     3. hard target(“硬目标”)是什么?

         hard就是正常网络训练的目标,是要尽量的完成正确的分类。

     4. 两个目标函数究竟是什么?

        两个目标函数也就是对应的上面的soft target和hard target。这个体现在Student Network会有两个loss,分别对应上面两个问题求得的交叉熵,作为小网络训练的loss function。

     5. 具体蒸馏是如何训练的?

        Teacher:  对softmax(T=20)的输出与原始label求loss。

        Student:(1)对softmax(T=20)的输出与Teacher的softmax(T=20)的输出求loss1。

                         (2)对softmax(T=1)的输出与原始label求loss2。

                         (3)loss = loss1+loss2


在弄清楚上面的问题以后,我们就可以进行代码复现了,笔者选择的是Pytorch,具体代码和效果对比会稍后整理一下上传至本人的github,先占坑~


猜你喜欢

转载自blog.csdn.net/Lucifer_zzq/article/details/79489248