Recently, while learning about algorithms for denoising and for dealing with bad-weather images, I came across knowledge distillation. As a general technique in deep learning, it is not only widely used in natural language processing but also very popular in fields such as computer vision.
Overview
Simply put, knowledge distillation distills a large teacher network into a small student network, that is, it carries out a process of knowledge transfer. The teacher network may even be an ensemble of many networks, making it a very bloated model. To make it easier to deploy on devices with limited computing power, such as mobile phones and autonomous-driving platforms, the teacher network is condensed into a student network.
Preliminary knowledge
We use a classification model to introduce the knowledge distillation process. In a classification model, the cross-entropy loss is used as the loss function; its derivation is described below.
Derivation of the cross-entropy loss function
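The derivation itself is not reproduced in the text here; as a reference, the following is a standard sketch of how the cross-entropy loss arises from the softmax output and maximum likelihood (the notation z_i, q_i, p_i is mine, not from the original post).

```latex
% Softmax turns the logits z_i into class probabilities q_i:
%   q_i = exp(z_i) / sum_j exp(z_j)
% For a one-hot (hard) label p with p_c = 1 for the true class c,
% maximizing the likelihood q_c is the same as minimizing -log q_c,
% which is exactly the cross-entropy between p and q:
\[
  q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}, \qquad
  L_{\mathrm{CE}}(p, q) = -\sum_i p_i \log q_i = -\log q_c .
\]
```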
Hard targets vs. soft targets
After understanding the cross-entropy loss function, we need to introduce two concepts: the hard target and the soft target. Taking the classification model as an example, each sample belongs to exactly one category, so every entry of its label is either 0 or 1; such a label is a hard target (indicated by the red box). A soft target, by contrast, is expressed as a probability distribution, which is exactly what our classification network outputs (indicated by the green box). Training with hard labels alone is, in this sense, not very scientific.
For example, in the three-class case below, the hard target simply tells us that the object in the picture is a horse, and that it is neither a cart nor a donkey. The soft target, however, gives the probability of each of the three classes, e.g. horse 0.7, donkey 0.2, cart 0.1. This not only gives the correct class but also conveys the relative differences among the incorrect classes: a donkey is still somewhat similar to a horse, while a cart is not similar at all. Compared with the hard target, this representation contains more information.
Take the handwritten-digit classification shown in the figure below: the soft target contains a great deal of information about which digits the sample looks more like and which it looks less like, so it is more scientific than the hard target.
We can therefore take the soft targets produced by the trained teacher network and feed them to the student network as labels to learn from.
To sum it up in one sentence:
Compared with the hard target, the soft target contains more knowledge and information: which classes are more alike, which are less alike, and by how much. In particular, it gives the relative magnitudes of the probabilities assigned to the incorrect classes.
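As a concrete illustration of the horse/donkey/cart example above, the snippet below contrasts a hard target with a soft target and computes the cross-entropy against each; the student prediction is a made-up value used only for illustration.

```python
import numpy as np

# Three classes: [horse, donkey, cart]
hard_target = np.array([1.0, 0.0, 0.0])   # one-hot label: "this is a horse"
soft_target = np.array([0.7, 0.2, 0.1])   # teacher output: a donkey is closer to a horse than a cart

student_pred = np.array([0.6, 0.3, 0.1])  # hypothetical student prediction

def cross_entropy(target, pred, eps=1e-12):
    """H(target, pred) = -sum_i target_i * log(pred_i)."""
    return float(-np.sum(target * np.log(pred + eps)))

print(cross_entropy(hard_target, student_pred))  # only the horse probability matters
print(cross_entropy(soft_target, student_pred))  # every class contributes information
```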
To make the relative differences between categories more obvious, the concept of distillation temperature is introduced.
Distillation temperature
How should the distillation temperature be understood? It is actually very simple: a temperature coefficient T is added to the softmax calculation, q_i = exp(z_i / T) / Σ_j exp(z_j / T), so that the differences among the probabilities of the non-target classes are amplified and become visible. However, T cannot be too large, otherwise the "gap between rich and poor" shrinks too much and the distribution becomes nearly uniform.
A concrete calculation is shown below, comparing the output of the original softmax (T = 1) with the output at T = 3.
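The original worked numbers are not reproduced here; the following sketch illustrates the same comparison with made-up logits, showing how T = 3 softens the distribution and makes the non-target classes easier to compare.

```python
import numpy as np

def softmax_T(logits, T=1.0):
    """Softmax with temperature: exp(z_i / T) / sum_j exp(z_j / T)."""
    z = np.asarray(logits, dtype=float) / T
    z -= z.max()                      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [5.0, 2.0, 1.0]              # hypothetical logits for [horse, donkey, cart]

print(softmax_T(logits, T=1))         # ~[0.94, 0.05, 0.02]: non-target classes nearly vanish
print(softmax_T(logits, T=3))         # ~[0.61, 0.23, 0.16]: relative differences are easier to see
```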
Knowledge Distillation Model Architecture
The specific process:
Feed a large amount of data into the trained teacher network and obtain the soft labels (soft targets) at T = t. Then feed the same data into the untrained (or partially trained) student network and compute its soft predictions at the same temperature T = t; the loss between the soft predictions and the soft labels is computed. At the same time, the student network's prediction at T = 1, called the hard prediction, is also required, and its loss against the hard target is computed. Finally, the two losses are weighted and summed to obtain the total loss.
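A minimal PyTorch-style sketch of this weighted loss is given below. The weight alpha, the temperature T, and the T^2 rescaling follow the common KD formulation; the exact values and names are assumptions for illustration, not taken from the original post.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_labels, T=3.0, alpha=0.7):
    """Weighted sum of the soft (teacher) loss and the hard (label) loss."""
    # Soft loss: student's soft predictions at T = t vs. the teacher's soft labels at T = t.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # rescale so its gradients are comparable to the hard loss

    # Hard loss: student's prediction at T = 1 vs. the ground-truth hard target.
    hard_loss = F.cross_entropy(student_logits, hard_labels)

    return alpha * soft_loss + (1.0 - alpha) * hard_loss

# Hypothetical usage with random tensors (batch of 4, 10 classes):
student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
```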
As for what the student learns from the teacher: it can be the final output (the soft target), an intermediate feature layer, or an attention map.
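For the intermediate-feature option, one common approach (an assumed example, not described in the original post) is to adapt the student's feature map to the teacher's channel count with a 1x1 convolution and pull the two together with an MSE loss:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistiller(nn.Module):
    """Match a student feature map to a teacher feature map (hypothetical shapes)."""
    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv adapter so the student features get the teacher's channel count.
        self.adapter = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        # Teacher features are detached: they serve as a fixed target.
        return F.mse_loss(self.adapter(student_feat), teacher_feat.detach())

# Hypothetical feature maps: student has 64 channels, teacher has 256, both 32x32.
distiller = FeatureDistiller(64, 256)
loss = distiller(torch.randn(2, 64, 32, 32), torch.randn(2, 256, 32, 32))
```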
Why is knowledge distillation useful?
In the figure, each ellipse represents the solution space that a network converges to. It can be seen that the solution space the teacher network converges to is very close to the solution space the student network converges to on its own.