DeiT: Training data-efficient image transformers & distillation through attention

This article mainly uses training strategies and knowledge distillation to improve the training speed and the performance of the model.

Original text link: Training data-efficient image transformers & distillation through attention
Source code address: https://github.com/facebookresearch/deit
A well-written reference article: Transformer learning (4) - DeiT
For a quick introduction to knowledge distillation, see: Knowledge Distillation classic paper notes

Training data-efficient image transformers & distillation through attention[PMLR2021]

Abstract

Although ViT achieves very high performance on classification tasks, it relies on large-scale compute infrastructure and pre-training on hundreds of millions of images to reach that level. These two requirements limit its adoption.

Therefore, the authors propose a new training strategy that trains a competitive convolution-free Transformer on ImageNet using only a single computer in less than 3 days. With no external data, it reaches a top accuracy of 83.1% (86M parameters) on ImageNet.

Second, the authors propose a strategy based on knowledge distillation. It relies on a distillation token to ensure that the student model learns from the teacher model through attention; the teacher model is usually convolution-based. The resulting Transformer is competitive with the state of the art on ImageNet (85.2%).

1 Introduction

Recently, there has been increasing interest in architectures that exploit the attention mechanism within convnets, as well as hybrid architectures that port transformer components into ConvNets to solve vision tasks. In this paper, the authors use a pure Transformer structure, but in the knowledge distillation strategy a convnet is used as the teacher network, so the student can inherit the convnet's inductive bias.

The ViT model relied on a large private labeled dataset of 300 million images to achieve its best results, and its authors also concluded that it does not generalize well when trained with insufficient data.

In this paper, the authors train a Vision Transformer in two to three days (53 hours of pre-training, plus optionally 20 hours of fine-tuning) on a single 8-GPU node, and it competes with ConvNets that have a similar number of parameters and similar efficiency. ImageNet is used as the only training set.

The authors also use a token-based strategy when distilling the model, in which the alembic symbol ⚗ denotes the distillation token.

In summary, the following contributions are made:

  1. The network contains no convolutional layers, yet achieves results competitive with the state of the art on ImageNet without any external data. Two new model variants, DeiT-S and DeiT-Ti, have fewer parameters and can be seen as counterparts of ResNet-50 and ResNet-18.
  2. A new distillation procedure based on a distillation token is introduced. It plays the same role as the class token, except that its purpose is to reproduce the label estimated by the teacher network. The two tokens interact in the transformer through attention.
  3. Models pre-trained on ImageNet are competitive when transferred to different downstream tasks.

2 Method

Training strategy:
Train at lower resolutions and fine-tune the network at larger resolutions, which speeds up full training and improves accuracy under mainstream data augmentation schemes.

When increasing the resolution of the input image, the patch size is kept constant, so the number N of input patches changes. Thanks to the architecture of the transformer blocks and the class token, the model and classifier do not need to be modified to handle more tokens. However, the positional embeddings do need to be adapted, since there is one positional embedding per patch (N in total); they are interpolated to the new grid size.
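
As a rough illustration of this adjustment, here is a minimal PyTorch sketch of interpolating learned positional embeddings to a new patch grid. The function and tensor names are illustrative and not taken from the DeiT repository:

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_grid, new_grid):
    """Interpolate learned positional embeddings to a new patch grid.

    pos_embed: (1, 1 + old_grid*old_grid, dim), class-token embedding first
               (illustrative layout). Returns (1, 1 + new_grid*new_grid, dim).
    """
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = patch_pos.shape[-1]
    # Reshape the per-patch embeddings into a 2D grid so that
    # image-style bicubic interpolation can be applied.
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)

# e.g. going from 224px input (14x14 patches of 16px) to 384px input (24x24 patches)
pos_embed = torch.randn(1, 1 + 14 * 14, 768)
print(resize_pos_embed(pos_embed, 14, 24).shape)  # torch.Size([1, 577, 768])
```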

Distillation: We first assume that a powerful image classifier is available to serve as the teacher model. It can be a convnet or a mixture of classifiers. This section introduces hard distillation versus soft distillation, and the distillation token.

First, here is an illustration borrowed from a Zhihu post. The teacher model is a known model that is larger and performs better. During distillation the teacher model is not trained; it only serves as a guide, and we simply reuse the extra information produced when the teacher maps an image to its output. In ordinary classification training, the only supervision we have is the image and its classification label: 1 for the correct class and 0 for everything else. The teacher model, however, produces a probability distribution over all classes through the softmax function, and we use this distribution to train the student model. Besides the positive class, the negative classes also carry a lot of information, which the ground truth alone cannot provide. The teacher's probability distribution is therefore equivalent to adding new label information to the training of the student model. For more details, see this link: Knowledge Distillation
(Figure: teacher-student knowledge distillation diagram, borrowed from a Zhihu post)

1. Soft distillation:
Minimize the Kullback-Leibler divergence between the softmax of the teacher model and the softmax of the student model. Let $Z_t$ be the logits of the teacher model and $Z_s$ the logits of the student model, let $\tau$ denote the distillation temperature, $\lambda$ the coefficient balancing the Kullback-Leibler divergence loss (KL) and the cross-entropy ($\mathcal{L}_{CE}$) on the ground-truth label $y$, and $\psi$ the softmax function. The distillation objective combines the normal cross-entropy loss on $y$ with the divergence term:

$$\mathcal{L}_{\text{global}} = (1-\lambda)\,\mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \lambda\,\tau^2\,\mathrm{KL}\big(\psi(Z_s/\tau), \psi(Z_t/\tau)\big)$$
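
A minimal PyTorch sketch of this soft distillation loss might look as follows (the function name and the default values of τ and λ are illustrative, not taken from the DeiT code):

```python
import torch.nn.functional as F

def soft_distillation_loss(student_logits, teacher_logits, labels, tau=3.0, lam=0.5):
    """Soft distillation: cross-entropy on the true labels plus a
    temperature-scaled KL divergence to the teacher's soft targets."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / tau, dim=-1),   # student log-probabilities
        F.softmax(teacher_logits / tau, dim=-1),       # teacher soft targets
        reduction="batchmean",
    ) * (tau * tau)                                    # standard tau^2 rescaling
    return (1.0 - lam) * ce + lam * kl
```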

2. Hard distillation variant:
Take the teacher model's prediction $y_t = \arg\max_c Z_t(c)$ as a true label. For a given image, the hard label $y_t$ may change depending on the specific data augmentation. This option outperforms the traditional one while being parameter-free and conceptually simpler: the teacher prediction $y_t$ plays the same role as the true label $y$. The distillation objective is:

$$\mathcal{L}_{\text{global}}^{\text{hardDistill}} = \tfrac{1}{2}\,\mathcal{L}_{\text{CE}}(\psi(Z_s), y) + \tfrac{1}{2}\,\mathcal{L}_{\text{CE}}(\psi(Z_s), y_t)$$
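
A corresponding sketch of the hard distillation loss, again with illustrative names. Note that in DeiT the teacher term is actually applied to the distillation-token head, whereas this simplified version uses a single set of student logits:

```python
import torch.nn.functional as F

def hard_distillation_loss(student_logits, teacher_logits, labels):
    """Hard distillation: the teacher's argmax prediction is treated as a
    second ground-truth label, and the two cross-entropy terms are averaged."""
    teacher_labels = teacher_logits.argmax(dim=-1)                # y_t = argmax_c Z_t(c)
    ce_true = F.cross_entropy(student_logits, labels)             # against real labels y
    ce_teacher = F.cross_entropy(student_logits, teacher_labels)  # against teacher labels y_t
    return 0.5 * ce_true + 0.5 * ce_teacher
```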

3. Distill token:

A new token, the distillation token, is added to the initial embeddings (patch and class tokens). The distillation token is similar to the class token: it interacts with the other embeddings via self-attention and is output by the network after the last layer. The distillation embedding allows the model to learn from the teacher's predicted output, as in regular distillation, while remaining complementary to the class embedding.
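
The following hypothetical PyTorch module sketches how such a distillation token could be prepended next to the class token before the transformer blocks (the class and parameter names are assumptions for illustration, not DeiT's actual implementation):

```python
import torch
import torch.nn as nn

class DistilledEmbedding(nn.Module):
    """Sketch: prepend a class token and a distillation token to the patch
    embeddings and add positional embeddings, before the transformer blocks."""

    def __init__(self, num_patches=196, dim=768):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, dim))               # the extra distillation token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, dim))  # +2 for both tokens

    def forward(self, patch_embeddings):            # (B, num_patches, dim)
        b = patch_embeddings.shape[0]
        cls = self.cls_token.expand(b, -1, -1)
        dist = self.dist_token.expand(b, -1, -1)
        x = torch.cat([cls, dist, patch_embeddings], dim=1)
        return x + self.pos_embed                   # both tokens then attend with the patches
```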

Distillation strategy:

  1. Fine-tuning: Both the ground-truth labels and the teacher predictions are used during the higher-resolution fine-tuning stage, with a teacher model operating at the same target resolution; fine-tuning with the true labels only was also tested, but it reduces the benefit of the teacher.
  2. Joint classifiers: At test time, both the class embedding and the distillation embedding produced by the transformer are attached to their own linear classifiers, and each can infer the image label. The two heads are fused by late fusion: the softmax outputs of the two classifiers are added for the final prediction (see the sketch below).
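
A minimal sketch of this late fusion at inference time (illustrative function name, assuming the two heads have already produced their logits):

```python
import torch

def fused_prediction(cls_logits, dist_logits):
    """Late fusion at test time: add the softmax outputs of the class-token
    classifier and the distillation-token classifier, then take the argmax."""
    probs = cls_logits.softmax(dim=-1) + dist_logits.softmax(dim=-1)
    return probs.argmax(dim=-1)
```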

A comparison of the results of different distillation methods is written in the experimental section.

3 Conclusion

This article does not improve the ViT model itself , but uses some training strategies to make it easier to train and improve the performance of the model.

The core of this paper is the knowledge distillation strategy, which adds the prediction information of negative samples to the training process and lets the student inherit the inductive bias of the teacher model (a convnet, which here outperforms a plain Transformer); in effect, it is a supplement to the label information.


Finally, I wish you all success in scientific research, good health, and success in everything~
