KDGAN: Knowledge Distillation with Generative Adversarial Networks (paper notes)

Paper address: http://papers.nips.cc/paper/7358-kdgan-knowledge-distillation-with-generative-adversarial-networks.pdf
GitHub address: https://github.com/xiaojiew1/KDGAN/

Motivation

When training a lightweight classifier, knowledge distillation (KD) can converge with only a small number of samples and training iterations, but it is difficult for the student to learn the real data distribution from the teacher alone. The alternative is to train the classifier adversarially with a GAN: adversarial training does let it learn the true data distribution, but it takes a long time to reach equilibrium because of high-variance gradient updates. To address both limitations, this paper proposes the KDGAN framework, which consists of a classifier (the student net), a teacher net, and a discriminator. The classifier and the teacher learn from each other through a distillation loss, and the classifier is additionally trained against the discriminator through an adversarial loss.

Method

The paper uses the multi-label classification task as its running example. The KDGAN framework is shown in the figure. In addition to the distillation loss from the teacher net to the classifier (as in standard KD) and the adversarial loss between the classifier and the discriminator (as in NaGAN, the naive GAN baseline), the authors also define a distillation loss from the classifier to the teacher net and an adversarial loss between the teacher net and the discriminator. That is, both the classifier and the teacher net act as generators, and the labels they generate are treated as fake by the discriminator. At the same time, the classifier and the teacher net learn each other's knowledge by distilling soft labels from each other, so that they agree on which pseudo-labels to generate.

Figure: the KDGAN framework
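To make that loss structure concrete, here is a minimal PyTorch-style sketch of the three objectives. It assumes the discriminator emits a logit for each (input, label) pair; the function name, the alpha/beta trade-off weights, and the use of plain KL divergence for the mutual distillation terms are illustrative assumptions, not the authors' exact formulation.

```python
import torch
import torch.nn.functional as F

def kdgan_objectives(cls_logits, tch_logits, d_real, d_fake_cls, d_fake_tch,
                     alpha=0.5, beta=0.5):
    """Return (classifier loss, teacher loss, discriminator loss).

    d_real     : discriminator logit on a true (input, label) pair
    d_fake_cls : discriminator logit on the label generated by the classifier
    d_fake_tch : discriminator logit on the label generated by the teacher
    """
    bce = F.binary_cross_entropy_with_logits

    # Discriminator: true pairs are real; labels from either generator are fake.
    d_loss = (bce(d_real, torch.ones_like(d_real))
              + bce(d_fake_cls, torch.zeros_like(d_fake_cls))
              + bce(d_fake_tch, torch.zeros_like(d_fake_tch)))

    # Mutual distillation: each network matches the other's soft labels (KL divergence).
    kd_c = F.kl_div(F.log_softmax(cls_logits, dim=-1),
                    F.softmax(tch_logits.detach(), dim=-1), reduction="batchmean")
    kd_t = F.kl_div(F.log_softmax(tch_logits, dim=-1),
                    F.softmax(cls_logits.detach(), dim=-1), reduction="batchmean")

    # Each generator tries to fool the discriminator and to agree with its peer.
    c_loss = alpha * bce(d_fake_cls, torch.ones_like(d_fake_cls)) + (1 - alpha) * kd_c
    t_loss = beta * bce(d_fake_tch, torch.ones_like(d_fake_tch)) + (1 - beta) * kd_t
    return c_loss, t_loss, d_loss
```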

To speed up the training of KDGAN, the authors argue, on the one hand, that the gradient the classifier receives from the teacher has lower variance than the gradient it receives from the discriminator, so their weighted average has lower variance than the gradient in plain GAN training, allowing KDGAN to converge faster. On the other hand, since the discrete labels sampled from the classifier and the teacher are not differentiable, the authors apply the Gumbel-Max trick to relax the discrete samples into a continuous distribution, so that gradients can be propagated through the sampling step.
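A minimal sketch of that relaxation, assuming a categorical label distribution parameterized by logits; the temperature value and tensor shapes here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=1.0):
    """Differentiable 'soft' one-hot sample from a categorical distribution."""
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    # Perturb the logits with the noise, then soften the argmax with a softmax.
    return F.softmax((logits + gumbel) / temperature, dim=-1)

# Example: relax a label sampled from the classifier so gradients can flow back.
logits = torch.randn(4, 10, requires_grad=True)   # batch of 4, 10 candidate labels
soft_label = gumbel_softmax_sample(logits, temperature=0.5)
soft_label.sum().backward()                        # gradients reach `logits`
```

PyTorch also ships torch.nn.functional.gumbel_softmax, which implements the same relaxation and can optionally return a hard one-hot sample with a straight-through gradient.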

The training procedure is as follows: the classifier, the teacher net, and the discriminator are each pre-trained first, and then within every epoch each of the three is updated several times in turn.
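Below is a self-contained, hypothetical sketch of that schedule in PyTorch on random stand-in data. Pre-training is omitted for brevity, and the toy architectures, one update per network per epoch, equal loss weights, and optimizer settings are all my assumptions rather than the paper's configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, n_labels, batch = 32, 10, 64
classifier = nn.Linear(feat_dim, n_labels)                # lightweight student
teacher = nn.Sequential(nn.Linear(feat_dim, 64), nn.ReLU(),
                        nn.Linear(64, n_labels))          # stronger teacher
discriminator = nn.Linear(feat_dim + n_labels, 1)         # scores (input, label) pairs

opt_c = torch.optim.Adam(classifier.parameters(), lr=1e-3)
opt_t = torch.optim.Adam(teacher.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = F.binary_cross_entropy_with_logits

def d_score(x, label_probs):
    """Discriminator logit for an (input, label distribution) pair."""
    return discriminator(torch.cat([x, label_probs], dim=-1)).squeeze(-1)

# Stand-in features and true labels; each numbered block below would normally
# be repeated several times per epoch.
x = torch.randn(batch, feat_dim)
y = F.one_hot(torch.randint(0, n_labels, (batch,)), n_labels).float()

for epoch in range(5):
    # 1) Discriminator: true pairs are real, labels from either generator are fake.
    with torch.no_grad():
        p_c, p_t = F.softmax(classifier(x), dim=-1), F.softmax(teacher(x), dim=-1)
    d_loss = (bce(d_score(x, y), torch.ones(batch))
              + bce(d_score(x, p_c), torch.zeros(batch))
              + bce(d_score(x, p_t), torch.zeros(batch)))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # 2) Teacher: fool the discriminator and distill from the classifier.
    p_t = F.softmax(teacher(x), dim=-1)
    t_loss = (bce(d_score(x, p_t), torch.ones(batch))
              + F.kl_div(p_t.log(), p_c, reduction="batchmean"))
    opt_t.zero_grad(); t_loss.backward(); opt_t.step()

    # 3) Classifier: fool the discriminator and distill from the updated teacher.
    p_c = F.softmax(classifier(x), dim=-1)
    with torch.no_grad():
        p_t = F.softmax(teacher(x), dim=-1)
    c_loss = (bce(d_score(x, p_c), torch.ones(batch))
              + F.kl_div(p_c.log(), p_t, reduction="batchmean"))
    opt_c.zero_grad(); c_loss.backward(); opt_c.step()
```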

Experiment

KD loss: L2 loss, KL divergence
Number of experiments: 10
Application scenarios: model compression, image label recommendation
Datasets: MNIST, CIFAR-10, YFCC100M

Results

Results on MNIST (figure omitted)
Hyperparameters (figure omitted)
Results on YFCC100M (figure omitted)

Thoughts

The method is not quite the KDGAN I had imagined, so it is worth reproducing.

Original post: blog.csdn.net/qq_43812519/article/details/105474815