[Lightweight Deep Learning] Combining Knowledge Distillation with NLP Language Models

Knowledge Distillation

Student : Wenxuan Zeng

School : University of Electronic Science and Technology of China

Date : 2022.3.25 - 2022.4.3



Reference paper: Distilling the Knowledge in a Neural Network

This paper is the pioneering work on knowledge distillation, presented at the NIPS 2014 Deep Learning Workshop, and it is well worth studying. So I started with this paper to learn knowledge distillation, and then moved on to how knowledge distillation can be used to compress the BERT model.

1 Definition of Knowledge

If knowledge were simply the parameters in the model, it would be hard to transfer, because two different models do not have a one-to-one correspondence between their parameters. Instead, the relative sizes of the class probabilities in the teacher network's predictions implicitly contain knowledge; the paper describes this as the mapping from input vectors to output vectors. As an intuitive example, for a picture of a car the model assigns a predicted probability to every class: a small probability to "bus" and an even smaller one to "carrot". The teacher network can then teach the student network knowledge such as: this picture is far more likely to be a car than a bus or a carrot, and it looks more like a bus than a carrot. In effect, knowledge includes the relative information between the correct class and the incorrect classes.

2 Soft targets


One way to transfer the generalization ability of the large, cumbersome model to the small model is to use the class probabilities produced by the large model as soft targets for training the small model. Soft targets have higher entropy and therefore provide more detailed information, while hard targets (one-hot encodings) have low entropy and provide less information.

What are soft/hard targets? For example, in a three-class problem the hard target for a car might be (0, 0, 1), while the soft target might be (0.1, 0.3, 0.6). Clearly the soft target contains more information, such as the relative information mentioned earlier ("this picture is more like a bus than a carrot").
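To make the "higher entropy" point concrete, here is a tiny sketch (plain NumPy, using the made-up numbers from the example above) comparing the entropy of a hard target and a soft target:

```python
import numpy as np

def entropy(p, eps=1e-12):
    """Shannon entropy (in nats) of a probability vector."""
    p = np.asarray(p, dtype=float)
    return float(-np.sum(p * np.log(p + eps)))

hard_target = [0.0, 0.0, 1.0]   # one-hot label: "car"
soft_target = [0.1, 0.3, 0.6]   # teacher's softened prediction

print(entropy(hard_target))     # ~0.0  -> only says which class is correct
print(entropy(soft_target))     # ~0.90 -> also encodes how similar the classes are
```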

Hard loss:

$$L_{hard} = -\frac{1}{N}\sum_{i=1}^{N} \log P(x_i)$$

Soft loss:

$$L_{soft} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C} y_{ij}\, \log P_j(x_i)$$

where $N$ is the number of samples, $C$ is the number of classes, $y_{ij}$ is the teacher's softened probability of class $j$ for sample $i$, and $P_j(x_i)$ is the student's predicted probability of class $j$.

3 T-Softmax

Recall the role of softmax. In classification tasks, softmax compresses the class scores into probabilities in the range [0, 1] that sum to 1. The softmax expression is:

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$
T-softmax modifies softmax by dividing each logit $z_i$ by a temperature $T$:

$$q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$$

Here $T$ is the distillation temperature. When $T = 1$ this reduces to the ordinary softmax; when $T > 1$ we obtain soft targets.

Conclusion: the larger $T$ is, the softer the prediction becomes and the closer the class probabilities are to each other, so the more knowledge it conveys.
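A minimal sketch of the temperature-scaled softmax in PyTorch (the example logits and the helper name `t_softmax` are made up for illustration), showing how a larger T flattens the distribution:

```python
import torch
import torch.nn.functional as F

def t_softmax(logits, T=1.0):
    """Temperature-scaled softmax: softmax(logits / T)."""
    return F.softmax(logits / T, dim=-1)

logits = torch.tensor([5.0, 2.0, 0.5])   # hypothetical class logits
for T in (1.0, 2.0, 5.0):
    print(T, t_softmax(logits, T).tolist())
# As T grows, the probabilities move closer together, i.e. the targets get "softer".
```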


4 Knowledge Distillation

4.1 Distillation process

The knowledge distillation process works as follows. The teacher network is run at temperature $t$ to produce soft labels, and the student network is trained at the same temperature $t$ to produce soft predictions; by fitting the soft predictions to the soft labels, the student is guided to learn the knowledge contained in the teacher network (soft labels are analogous to the teacher's own words and deeds). In addition, the student network also produces hard predictions at temperature 1, and the cross-entropy between these predictions and the hard (one-hot) labels gives the student loss (hard labels are analogous to textbook knowledge).


4.2 Loss function

$$L = \gamma\, L_{hard} + (1-\gamma)\, T^2\, L_{soft}$$

Note that the soft loss must be multiplied by $T^2$. This way, when the distillation temperature is changed, the relative contributions of the hard and soft targets remain roughly the same.
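A sketch of this combined objective in PyTorch; the function name `distillation_loss`, the weights, and the use of KL divergence for the soft term (the usual implementation choice) are my own assumptions rather than code from the paper:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, gamma=0.5):
    """L = gamma * hard cross-entropy + (1 - gamma) * T^2 * soft loss."""
    # Hard loss: ordinary cross-entropy with the ground-truth labels (temperature 1).
    hard = F.cross_entropy(student_logits, labels)
    # Soft loss: KL divergence between the softened student and teacher distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    # Multiplying by T^2 keeps the soft term's gradient scale comparable when T changes.
    return gamma * hard + (1.0 - gamma) * (T ** 2) * soft

# Typical usage in a training step (teacher frozen):
# with torch.no_grad():
#     teacher_logits = teacher(inputs)
# loss = distillation_loss(student(inputs), teacher_logits, labels)
```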


4.3 Logit matching is a special form of knowledge distillation

In the paper Model Compression (SIGKDD'06), the authors compressed a model through knowledge transfer; specifically, they matched the logits of the teacher network and the student network with an MSE loss. In the distillation paper, the authors show that this form of compression is a special case of distillation.

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$

where $q_i$ is the probability predicted by the student network (from logits $z_i$) and $p_i$ is the probability predicted by the teacher network (from logits $v_i$).

Assuming the distillation temperature $T$ is high enough, we can apply the first-order Taylor expansion $e^x \approx 1 + x$:

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i) \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$

If we further assume that the logits are zero-mean for each sample, i.e. $\sum_j z_j = \sum_j v_j = 0$, then

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{NT^2}(z_i - v_i)$$

In summary, if the distillation temperature is high enough and the logits are zero-mean, knowledge distillation is equivalent to minimizing the logit-matching objective $\frac{1}{2}(z_i - v_i)^2$, i.e. the MSE between the logits.
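This equivalence can also be checked numerically: at a high temperature with zero-mean logits, the gradient of the soft cross-entropy with respect to the student logits approaches $(z_i - v_i)/(NT^2)$, which is (up to that constant factor) the gradient of $\frac{1}{2}(z_i - v_i)^2$. A small sketch with made-up logits:

```python
import torch
import torch.nn.functional as F

T = 100.0                                                  # a "high enough" temperature
z = torch.tensor([1.0, -0.5, -0.5], requires_grad=True)    # student logits (zero mean)
v = torch.tensor([2.0, -1.0, -1.0])                        # teacher logits (zero mean)

# Soft cross-entropy C = -sum_i p_i * log q_i with p = softmax(v/T), q = softmax(z/T).
C = -(F.softmax(v / T, dim=-1) * F.log_softmax(z / T, dim=-1)).sum()
C.backward()

N = z.numel()
print(z.grad)                             # gradient of the distillation objective
print(((z - v) / (N * T ** 2)).detach())  # ~ the same values: MSE logit matching
```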

In practice, however, the temperature cannot be arbitrarily high. When the temperature is too low, the softmax values corresponding to the small logits are suppressed toward 0 and contribute almost nothing, so distillation has little effect; when the temperature is too high, the probabilities of all classes converge toward each other, which may introduce noise.

The distillation temperature T therefore has to be chosen empirically; in general, an intermediate temperature works best.


4.4 Simple calculation of knowledge distillation

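A small worked example (with made-up logits for a single sample) of how the softened targets and the per-sample soft loss are computed at different temperatures:

```python
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([6.0, 2.0, -2.0])   # hypothetical teacher logits
student_logits = torch.tensor([4.0, 3.0, -1.0])   # hypothetical student logits

for T in (1.0, 3.0, 10.0):
    p = F.softmax(teacher_logits / T, dim=-1)           # teacher's soft targets
    log_q = F.log_softmax(student_logits / T, dim=-1)   # student's log-probabilities
    soft_loss = -(p * log_q).sum()                      # soft cross-entropy for this sample
    print(f"T={T}: soft targets={p.tolist()}, soft loss={soft_loss.item():.4f}")
```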

5 Experimental Design

Impressively, the student network can achieve a form of zero-shot learning. For example, even though the student has not directly seen knowledge such as the translation invariance captured by a CNN teacher, it can still pick it up through knowledge transfer from the teacher. If the digit 3 is removed from the student's training (transfer) set, the student can still learn the characteristics of 3 from the teacher's knowledge (the authors manually adjusted the bias of that class).

6 Development Direction of Knowledge Distillation

  • Introducing teaching assistants, or multiple teachers and multiple students
  • Richer representations of knowledge (e.g., intermediate-layer features)
  • Distillation for multimodal models, knowledge graphs, and large pre-trained models

7 Research on knowledge distillation in the field of NLP

In this part, I selected several classic BERT distillation papers and studied the ideas behind BERT distillation. The following are some of my study notes.

7.1 Distilled BiLSTM

Link: Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

Method: The teacher model is a fine-tuned BERT-Large, the student model is a BiLSTM with ReLU, and the distillation objective is the cross-entropy between the student's predictions and the hard labels plus the MSE between the student's logits and BERT-Large's logits.
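A rough sketch of that objective as I understand it (the weight `alpha` and the function name are placeholders, not from the paper):

```python
import torch
import torch.nn.functional as F

def distilled_bilstm_loss(student_logits, teacher_logits, labels, alpha=0.5):
    """Hard-label cross-entropy plus MSE between student and teacher logits."""
    ce = F.cross_entropy(student_logits, labels)       # supervision from hard labels
    mse = F.mse_loss(student_logits, teacher_logits)   # logit matching against BERT-Large
    return alpha * ce + (1.0 - alpha) * mse
```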

7.2 BERT-PKD

Link: Patient Knowledge Distillation for BERT Model Compression (EMNLP'19)

Method: Instead of distilling only from the last layer of the teacher, knowledge is also extracted from the teacher's intermediate layers. The paper proposes two layer-selection schemes: PKD-Skip (distill from every k-th layer) and PKD-Last (distill from the last k layers).


The gap between the predictions of the student and teacher models is measured with a cross-entropy loss between their predicted distributions.


In addition to letting the student imitate the teacher, a task-specific cross-entropy loss with the ground-truth labels is also defined.


In addition, an MSE loss between the normalized hidden states of matched layers is defined as the "patient" distillation term.

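A sketch of that normalized hidden-state (patient) term, assuming we already have the [CLS] hidden states of the matched student/teacher layers; the layer mapping and the function name are assumptions on my part:

```python
import torch
import torch.nn.functional as F

def patient_loss(student_hiddens, teacher_hiddens):
    """MSE between L2-normalized hidden states of matched layers (e.g. [CLS] vectors)."""
    loss = 0.0
    for h_s, h_t in zip(student_hiddens, teacher_hiddens):
        loss = loss + F.mse_loss(F.normalize(h_s, dim=-1), F.normalize(h_t, dim=-1))
    return loss
```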

7.3 DistilBERT

Link: DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (NeurIPS'19 Workshop)

Method: Knowledge distillation is applied during the pre-training stage to compress BERT. To exploit the inductive biases learned by the large model during pre-training, a triple loss is introduced that combines the language modeling loss, the distillation loss, and a cosine-distance loss.
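A hedged sketch of how the three terms might be combined (the weights, temperature, and the exact way the logits and hidden states are gathered are my assumptions; DistilBERT does keep the teacher's hidden size, which the cosine term relies on):

```python
import torch
import torch.nn.functional as F

def distilbert_loss(student_logits, teacher_logits, student_hidden, teacher_hidden,
                    mlm_labels, T=2.0, w_mlm=1.0, w_kd=1.0, w_cos=1.0):
    """Masked-LM loss + distillation loss on softened logits + cosine loss on hidden states."""
    vocab = student_logits.size(-1)
    # 1) Ordinary masked language modeling loss of the student (-100 marks unmasked tokens).
    mlm = F.cross_entropy(student_logits.reshape(-1, vocab), mlm_labels.reshape(-1),
                          ignore_index=-100)
    # 2) Distillation loss: KL between the softened student and teacher distributions.
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction="batchmean") * (T ** 2)
    # 3) Cosine loss aligning the directions of the student and teacher hidden states.
    cos = 1.0 - F.cosine_similarity(student_hidden, teacher_hidden, dim=-1).mean()
    return w_mlm * mlm + w_kd * kd + w_cos * cos
```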


7.4 TinyBERT

Link: TinyBERT: Distilling BERT for Natural Language Understanding (EMNLP'20 Findings)

Method: A two-stage learning framework is proposed in which the teacher is distilled both during pre-training and during fine-tuning. The resulting 4-layer BERT has 7.5x fewer parameters and runs 9.4x faster while retaining 96.8% of the teacher's performance; a 6-layer model trained this way is close to BERT-base and surpasses BERT-PKD and DistilBERT. The paper also proposes distilling the attention matrices, using MSE as the loss function to fit the student's attention matrices to the teacher's.


At the same time, knowledge distillation is also applied to the embedding layer and the hidden layers, again using MSE as the loss function.


Finally, a cross-entropy loss is used to measure the gap between the logits of the teacher and student models.

Based on the distillation objectives above, the loss used for each layer is chosen according to which layer (embedding, Transformer, or prediction layer) is being distilled.

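A simplified sketch of the per-layer TinyBERT-style losses; the projection matrix `W` (used because the student's hidden size is smaller than the teacher's) and the layer mapping are assumptions of this sketch:

```python
import torch
import torch.nn.functional as F

def attention_loss(attn_s, attn_t):
    """MSE between student and teacher attention matrices of a matched layer."""
    return F.mse_loss(attn_s, attn_t)

def hidden_loss(h_s, h_t, W):
    """MSE between projected student hidden states and teacher hidden states."""
    return F.mse_loss(h_s @ W, h_t)   # W maps the student width to the teacher width

def prediction_loss(logits_s, logits_t, T=1.0):
    """Soft cross-entropy between student and teacher predictions."""
    p_t = F.softmax(logits_t / T, dim=-1)
    return -(p_t * F.log_softmax(logits_s / T, dim=-1)).sum(-1).mean()
```

The embedding-layer loss has the same form as `hidden_loss`, applied to the embedding outputs.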

7.5 MobileBERT

Link: MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices (ACL'20)

Method: A bottleneck structure, together with a mechanism that balances self-attention and feed-forward networks, is used to distill knowledge from the teacher model into a much narrower student model. (Detailed notes are in the Paper Understanding section of my previous document.)


7.6 MiniLM

Link: MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers (NeurIPS'20)

Method: Although previous work has distilled almost every part of the model, from the embedding layer to the hidden layers, the attention layers, and finally the prediction layer, this paper still finds a new signal to distill and achieves very good results. It distills the self-attention module and proposes the scaled dot-product between values (the value relation) as a new form of deep self-attention knowledge. In addition, it uses a teacher assistant to help distill very large models.


The difference between the teacher's and student's self-attention distributions is measured by KL divergence.


The paper also defines a value relation, which is essentially a scaled dot product over the value vectors; KL divergence is then used to measure the difference between the two value-relation matrices.

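A sketch of the two MiniLM terms, computed per attention head from the last layer's queries, keys, and values; the tensor shapes ([batch, heads, length, head_dim]) and the simple averaging are my assumptions about the details:

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_relation(a, b):
    """softmax(a @ b^T / sqrt(d)) along the last dimension."""
    d = a.size(-1)
    return F.softmax(a @ b.transpose(-1, -2) / math.sqrt(d), dim=-1)

def kl(p_teacher, p_student, eps=1e-12):
    """KL(teacher || student), averaged over heads and positions."""
    return (p_teacher * (torch.log(p_teacher + eps)
                         - torch.log(p_student + eps))).sum(-1).mean()

def minilm_loss(q_s, k_s, v_s, q_t, k_t, v_t):
    # Attention-distribution transfer: softmax(Q K^T / sqrt(d_k)) of teacher vs student.
    attn_loss = kl(scaled_dot_relation(q_t, k_t), scaled_dot_relation(q_s, k_s))
    # Value-relation transfer: softmax(V V^T / sqrt(d_v)) of teacher vs student.
    vr_loss = kl(scaled_dot_relation(v_t, v_t), scaled_dot_relation(v_s, v_s))
    return attn_loss + vr_loss
```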

Conclusion: the paper argues that distilling only the last layer works better than layer-to-layer distillation; there is no need for a strict layer correspondence between the two models, and distilling only the last layer still improves the student and gives it stronger generalization ability.

8 My thoughts on knowledge distillation⭐

  • From the papers above, some works find that patiently distilling the intermediate layers brings good results, while others insist that distilling only the last layer is better. Each justifies its choice in its own paper, so I think adaptively selecting which layers to distill is a direction worth exploring.
  • Recently I have also been reading AutoLoss-related papers in AutoML. Some of these ideas use a teacher network (which handles an auxiliary task) to guide the student network (the model that performs the actual task) to learn an optimal loss function. There are many ways to realize this, and I think the idea of "a teacher guiding the student to better complete the target task" is promising and worth further thought.
  • Most of the papers I have read are about "knowledge transfer from teacher to student". I wonder whether we can also improve from the perspective of "teaching benefits the teacher as well": give something back to the teacher, and let the student teach the teacher something too. This could be described as "removing the fixed teacher model and letting two or more models learn from and promote each other".

The summary above was written earlier; here is a supplement from two weeks later:
Later I found that many papers do study the teacher-student co-teaching process, including introducing curriculum learning and generating pseudo labels, and some papers use techniques such as self-distillation, mutual learning between twin networks, and multi-stage distillation.
For two models learning from and distilling each other, some papers explicitly propose Mutual Distillation and related methods; for the problem of adaptively selecting layers to distill, there are also papers that design an attention-based distillation-layer selection scheme, which takes into account the semantics of specific layers for the specific task.
