Distilling the Knowledge in a Neural Network[论文阅读笔记]

本文是Hinton大神在网络压缩与迁移学习方向挖的一个坑
原文链接Distilling the Knowledge in a Neural Network


这种方法感觉受到了ensemble的启发,利用大型(teacher net)网络提取先验知识,将这种先验知识作为soft target让微型网络(student network)学习,有点像Boost中第一个分类器学到后调整weight让第二个分类器学习。当然相似中也有不同之处

Introduction

主旨就是从复杂网络(teacher net)中抽取训练数据的分布‘教给’简易网络(student net)

  • 复杂网络可以从
    1. ensmeble models 中学习
    2. 从单独的大型网络(使用正则项或dropout)中学习

复杂网络通常通过最大化log probability去学习多分类,这有个副作用是模型通常会赋予非正确答案一定的概率,即使这些概率都很小但是有一些是明显大于其他的(类似于softmax)。


  • 通过复杂网络产生的分类概率分布作为soft target来训练小模型

在transfer的过程中(即cumbersome network转向small network的时候)可以使用同样的训练集也可以使用单独的训练集(猜想:可能使用不同的训练集的时候效果会更好;但是文中说使用原始数据集效果更好)


值得注意的是 对于Mnist数据集中总是产生具有信心的正确结果,很多非正确数字的概率都是非常小的,举个栗子,对于正确的2来说,被分类为3的概率为106

被分类为7的概率为 109

。在这种情况下soft target的帮助就很小,所以Caruana在其文章中使用log函数解决这一问题。本文利用蒸馏解决这个问题(蒸馏会使得target变得更加soft)

Distillation(蒸馏)

 加入蒸馏后的softmax函数:

qi=exp(zi/T)jexp(zj/T)(1)

这里T是超参数,文中说是‘温度’,经过该参数之后的softmax会更加平滑,分布更加均匀而大小关系不变。T参数在设置为1的时候就是平常的softmax函数。
在知识转换阶段,设置复杂网络与简易网络相同的T参数。在此之后再从新将T设置为1


根据这篇博客再加入T之后的softmax的概率分布更加平滑,作为soft target时简易(student)网络能学到更多东西
这里写图片描述

另外使用matlab做蒸馏
这里写图片描述


在hinton的slides中也可以看到更加平滑的概率slides

这里写图片描述


具体蒸馏结构如下图所示:

这里写图片描述

这里λ

是hard target与soft target的权重

1、训练大模型:先用hard target,也就是正常的label训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的output。
3、训练小模型,在小模型的基础上再加一个额外的soft target的loss function,通过lambda来调节两个loss functions的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。
这里参考

将hard target的参数设置的比较小可以得到较好的结果,在做梯度下降时需要将soft target的梯度乘以1/T2

,这是因为将梯度的规模保持与hard target一样

ditiliation在特殊情况下相当于logits

首先给出交叉上的求导公式

Czi=1T(qipi)=1T(eziTjezjT)1T(eviTjevjT)(2)

其中vi是cumbersome model产生的logitsvi相当于真实分布,zi是distilled model
公式推导:
qi=1T(eziTjezjT)(2.1)

pi=1T(eviTjevjT)(2.2)

cross entropy=j pj×log qj(2.3)

对交叉熵求导即可得到公式(2)
并且作者给出当T非常大的时候distillation优化的目标等价于Caruana提取的对logits的平方误差求最优化
Czi1T(1+ziTN+jzjT)(1+viTN+jvjT)(3)

当我们假设logits是zero-means的则 jzj=jvj=0
Czi1NT2(zivi)(4)

实验部分

初步试验 Mnist数据集

  • 训练一个有两层具有1200个单元的隐藏层的大型网络(使用dropout和weight-constraints作为正则)值得注意的一点是dropout可以看做是share weights 的ensemble models

  • 另外一个小一点的网络具有两层800个单元隐藏层没有正则

结果是第一个网络test error 67个,第二个是146个;再加入soft target并且T设置为20之后小型网络test error达到74个


另外需要注意一点的是:

When the distilled net had 300 or more units in each of its two hidden layers, all temperatures above 8 gave fairly similar results. But when this was radically reduced to 30 units per layer, temperatures in the range 2.5 to 4 worked significantly better than higher or lower temperatures.

 该现象可能说明将概率设置的过于soften可能会导致一些问题尤其是在拟合能力较差的网络中


另外的重要发现

  • 遗漏数据集所有的数字3做训练后,distilled model只有206个test error只有206 其中133是对3的辨识错误(测试集中3有1010个),很多错误是bias过低导致(?)到bias增加到3.5时distilled model 给了109个error,其中14个是3
  • 训练集只有7和8时,distilled model有47.3%的测试错误,到bias降低7.6时优化了,降低了13.2%测试错误。
  • 一个问题:这样手动修正bias增加正确率是否有普适的意义

剩下的就是在语音数据及的实验以及大型数据及JFT的实验。具体可以参考原文。


转载: https://blog.csdn.net/qq_22749699/article/details/79460817

猜你喜欢

转载自blog.csdn.net/jdbc/article/details/80657822