Detailed explanation of ViT (Vision Transformer)

1. Introduction

Paper: https://arxiv.org/abs/2010.11929
Original source code: https://github.com/google-research/vision_transformer
PyTorch implementation: https://github.com/lucidrains/vit-pytorch

Although the Transformer architecture has become the de facto standard for NLP tasks, its application in computer vision is still limited. In vision, attention is either used in conjunction with convolutional networks or used to replace certain components of ConvNets while keeping their overall structure. The paper demonstrates that this reliance on CNNs is unnecessary and that a pure Transformer applied directly to sequences of image patches can perform image classification very well. When pre-trained on large amounts of data and transferred to several small or mid-sized image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), the Vision Transformer (ViT) attains excellent results compared to state-of-the-art CNNs while requiring substantially fewer computational resources to train.

2. Overall structure

[Figure: overall structure of ViT]
The figure shows the structure of the entire ViT, which is divided into 5 parts:

  • Split the image into patches (tokens)
  • Map each token to a Token Embedding
  • Add the Position Embedding to the corresponding Token Embedding
  • Feed the sequence into the Transformer Encoder
  • Use the CLS token's output for the multi-class classification

3. ViT input part

  1. The image is split into tokens
    For example, a 224x224x3 image is cut into non-overlapping 16x16 patches. This yields 14x14 = 196 tokens, and each token, once flattened, contains 16x16x3 = 768 values.

  2. Convert each token to a Token Embedding
    Flatten each 16x16x3 patch into a 1-dimensional vector of length 768,
    then apply a Linear layer that maps the 768 values to the embedding size expected by the Transformer Encoder (1024 in this example)

  3. Add the Position Embedding to the Token Embedding
    First, generate an extra Token Embedding for the CLS token (the * in the figure);
    then generate position codes for the whole sequence (the CLS symbol plus all Token Embeddings, corresponding to 0-9 in the figure);
    finally, add each Position Embedding to the Token Embedding at the corresponding position (a sketch of the whole input pipeline follows this list)
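
For concreteness, here is a minimal sketch of steps 1-3 in PyTorch. Shapes follow the running example (224x224x3 image, 16x16 patches, embedding size 1024); the names and the flattening code are illustrative, not the paper's implementation.

```python
import torch
import torch.nn as nn

image_size, patch_size, in_chans, embed_dim = 224, 16, 3, 1024
num_patches = (image_size // patch_size) ** 2            # 14 * 14 = 196 tokens
patch_dim = patch_size * patch_size * in_chans           # 16 * 16 * 3 = 768 values per token

to_embedding = nn.Linear(patch_dim, embed_dim)           # maps 768 -> 1024
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # extra CLS token embedding
pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))  # learned 1d positions

def embed(images):                                       # images: (B, 3, 224, 224)
    B = images.shape[0]
    # cut into non-overlapping 16x16 patches and flatten each patch to 768 values
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, num_patches, patch_dim)
    tokens = to_embedding(patches)                       # (B, 196, 1024)
    tokens = torch.cat([cls_token.expand(B, -1, -1), tokens], dim=1)  # prepend CLS -> (B, 197, 1024)
    return tokens + pos_embedding                        # element-wise add of position codes

print(embed(torch.randn(2, 3, 224, 224)).shape)          # torch.Size([2, 197, 1024])
```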

Note:
Question 1: Why add a CLS symbol?
CLS is a symbol used in BERT. In NLP tasks, the CLS token helps keep BERT's two pre-training tasks independent to a certain extent; ViT has only a single multi-class classification task, so the CLS symbol is not strictly necessary. The authors also ran experiments to verify this point: in ViT, training with and without the CLS token gives similar results.

In the paper, the authors want to stay as consistent as possible with the original Transformer, so a class token is used; the class token is also used in NLP classification tasks (as a global feature summarizing the sentence). Here the class token is treated as a global feature of the whole image: after obtaining this token's output, it is fed to an MLP head (which uses tanh as the nonlinear activation function) to make the classification prediction.

The design of the class token is borrowed entirely from NLP; it was not used in vision before. For example, with a residual network such as ResNet-50, the last stage produces a 14x14 feature map, and on top of this feature map an operation called GAP (global average pooling) is applied. The pooled feature is effectively flattened into a vector, which can be understood as a global image feature, and this feature is then used for classification.
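
As a small illustration of that GAP step (shapes are illustrative, e.g. a 2048-channel ResNet-50 last-stage feature map; this is not code from the paper):

```python
import torch
import torch.nn as nn

feature_map = torch.randn(2, 2048, 14, 14)      # (batch, channels, H, W), e.g. ResNet-50 last stage
global_feature = feature_map.mean(dim=(2, 3))   # global average pooling -> (2, 2048)
classifier = nn.Linear(2048, 1000)              # illustrative 1000-class head
logits = classifier(global_feature)             # (2, 1000)
```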

For the Transformer, if the model outputs n elements, why can't we simply do global average pooling over those n outputs to obtain a final feature, instead of adding a class token in front and using its output for classification?

Through experiments, the authors' final conclusion is that both methods work: you can either obtain a global feature by global average pooling and then classify, or use a class token. All experiments in the paper use the class token, mainly to stay as close as possible to the original Transformer. The authors do not want readers to attribute any good results to particular tricks or CV-specific changes; they just want to show that a standard Transformer can still do vision.
![Insert picture description here](https://img-blog.csdnimg.cn/68249c6c44dc4e9d9b3c0f030503e3cd.png)
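
A minimal sketch of the two options being compared, assuming the encoder outputs one 1024-d vector per token (197 tokens = 1 CLS + 196 patches); shapes and names are illustrative:

```python
import torch

encoder_output = torch.randn(2, 197, 1024)        # (batch, 1 CLS + 196 patches, dim)
cls_feature = encoder_output[:, 0]                # option 1: take the CLS token, (2, 1024)
gap_feature = encoder_output[:, 1:].mean(dim=1)   # option 2: GAP over the patch tokens, (2, 1024)
# Per the ablation, classifying from either feature gives similar accuracy.
```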

Question 2: Why is positional encoding required?
An RNN processes tokens one by one, so it has a natural notion of order that tells the model which words come first and which come later. In the Transformer, all words are input together and pass through the attention layers at once, so without positional encoding the model cannot tell which tokens come earlier and which come later. Therefore we need to give the model positional information that tells it where each word/token sits in the sequence.

The authors also ran ablation experiments on the positional encoding, mainly comparing three kinds (a small sketch of the first two follows the list):

  • 1d: the positional encoding commonly used in NLP; this is the encoding used throughout the paper
  • 2d: suppose the image is cut into a 3x3 grid. With 1d encoding the blocks are numbered 1 to 9; with 2d encoding they are labelled 11, 12, 13, 21, and so on, which is closer to a vision problem because it keeps the grid structure. Concretely, if the 1d position code has dimension D, then because both the row and column indices must be represented, the row gets a D/2-dimensional vector and the column gets a D/2-dimensional vector; the two D/2 vectors are concatenated to obtain a length-D vector, which is the 2d position code
  • Relative positional encoding: in 1d positional encoding, the distance between two patches can be expressed either as an absolute position or as the relative distance between them (the offset mentioned in the paper), which is another way to represent the positional relationship between image blocks
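
A rough sketch of how the 1d and 2d learned position codes from this ablation could be built (shapes only, with a 14x14 patch grid and D = 1024; illustrative, not the paper's implementation):

```python
import torch
import torch.nn as nn

grid, d = 14, 1024                        # 14x14 patch grid, embedding size D

# 1d: one learned length-D vector per sequence position (what the paper uses)
pos_1d = nn.Parameter(torch.zeros(grid * grid, d))

# 2d: a D/2 vector for the row index and a D/2 vector for the column index,
# concatenated into a length-D code for each patch
row_embed = nn.Parameter(torch.zeros(grid, d // 2))
col_embed = nn.Parameter(torch.zeros(grid, d // 2))
rows = row_embed[:, None, :].expand(grid, grid, d // 2)            # (14, 14, 512)
cols = col_embed[None, :, :].expand(grid, grid, d // 2)            # (14, 14, 512)
pos_2d = torch.cat([rows, cols], dim=-1).reshape(grid * grid, d)   # (196, 1024)
```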

But the final conclusion of this ablation is the same: the three representations perform similarly, as shown in the figure below:
[Figure: ablation results for the three positional encodings]

4. ViT Encoder part

[Figure: structure of the ViT Encoder]
Differences between the original Transformer encoder and the Encoder in ViT (a minimal pre-norm block sketch follows this list):

  • The Norm layer is moved in front of the attention and MLP sub-layers (pre-norm instead of post-norm);
  • There is no Pad symbol, since every image produces the same fixed number of tokens;
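
A minimal pre-norm encoder block illustrating the first point (LayerNorm applied before the attention and MLP sub-layers); the hyperparameters here are illustrative:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim=1024, heads=16, mlp_dim=4096):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))

    def forward(self, x):                                   # x: (B, 197, dim), no padding mask needed
        h = self.norm1(x)                                   # norm BEFORE attention (pre-norm)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # residual connection
        x = x + self.mlp(self.norm2(x))                     # norm BEFORE MLP, then residual
        return x

block = PreNormBlock()
print(block(torch.randn(2, 197, 1024)).shape)               # torch.Size([2, 197, 1024])
```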

5. CLS multi-classification output

In the end, every token gets a 1024-dimensional output. The first one (the CLS token's output vector) is taken out and fed into a fully connected layer to perform the multi-class classification.
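
A minimal sketch of this step (1000 classes is an illustrative choice, not from the text):

```python
import torch
import torch.nn as nn

encoder_output = torch.randn(2, 197, 1024)   # per-token outputs from the encoder
cls_out = encoder_output[:, 0]               # the first (CLS) token's 1024-d output
head = nn.Linear(1024, 1000)                 # fully connected multi-class head
logits = head(cls_out)                       # (2, 1000) class scores
```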

6. Inductive bias

Compared with CNNs, the vision transformer has much less image-specific inductive bias. In a CNN, locality and translation equivariance are built into every layer of the model; this prior knowledge runs through the entire model.

But in ViT, only the MLP layer is local and translation-equivariant; the self-attention layers are global. The 2D structure of the image is barely used by ViT (only when the image is cut into patches at the beginning and in the positional encoding; otherwise no vision-specific inductive bias is used).

Moreover, the positional encoding is randomly initialized and carries no 2D information at the start; all the distance and spatial-layout relationships between image blocks have to be learned from scratch.

This also foreshadows the later results: since the vision transformer uses little inductive bias, it is understandable that when pre-trained on small or medium-sized datasets it does not perform as well as convolutional neural networks.

7. References

https://blog.csdn.net/qq_38253797/article/details/126085344

https://blog.csdn.net/SeptemberH/article/details/123730336
