ViT: What is the role of the cls token in the Vision Transformer?

Zhihu: Vision Transformer super-detailed interpretation (principle analysis + code walkthrough)

CSDN: Understanding of cls_token and position_embed in ViT

CSDN: Why did ViT introduce cls_token?

CSDN: Some problems with the special class token in ViT


Vision Transformer surpasses CNNs on some tasks, benefiting from its aggregation of global information. In the ViT paper, the authors introduce a class token whose output serves as the classification feature.

If there were no cls_token, which patch token would we use for classification?

Under the self-attention mechanism, each patch token aggregates global information to some extent, but it is still dominated by its own features. The ViT paper also tries averaging all patch tokens, which means every patch contributes equally to the prediction; that seems unreasonable at first, yet in practice it performs essentially the same as introducing a cls_token.
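
To make the two options concrete, here is a minimal PyTorch sketch (the shapes, names, and linear head are illustrative assumptions, not the paper's code): the classification feature is either the class token's output or the mean of the patch tokens.

```python
import torch

# Sketch: encoder output `tokens` of shape (B, N + 1, D),
# with the class token at index 0 followed by N patch tokens.
B, N, D = 2, 196, 768
tokens = torch.randn(B, N + 1, D)

cls_feature = tokens[:, 0]            # (B, D): the class token's output
mean_feature = tokens[:, 1:].mean(1)  # (B, D): average of all patch tokens

# Either feature can feed a linear classifier; per the discussion above,
# both choices end up performing about the same.
head = torch.nn.Linear(D, 1000)
logits_cls, logits_mean = head(cls_feature), head(mean_feature)
```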

ViT introduces the class-token mechanism for the following reasons:

The Transformer's input is a sequence of patch embeddings, and its output is a sequence of patch features of the same length, but this must ultimately be reduced to a single category prediction.

  • You could average-pool all the patch features to compute an image-level feature, but as discussed above this seems unreasonable.
  • Instead, the authors introduce a class token that acts like a flag; its output feature can be classified by attaching a linear classifier.
  • During training, the class token's embedding is randomly initialized and added to the positional embedding. As the figure shows, a new embedding is prepended at position [0] of the Transformer input, so the final input length is N+1 (a minimal sketch of this pipeline follows the list).
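
The sketch below shows that input pipeline in PyTorch; the module and parameter names (PatchEmbedWithCls, cls_token, pos_embed) are our own, though they mirror common ViT implementations.

```python
import torch
import torch.nn as nn

class PatchEmbedWithCls(nn.Module):
    """Illustrative sketch of ViT's input preparation (names are ours)."""
    def __init__(self, num_patches: int = 196, dim: int = 768):
        super().__init__()
        # Randomly initialized, learned jointly with the network.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # One positional embedding per token, including slot 0 for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (B, N, D) patch embeddings from the linear projection.
        B = patch_tokens.shape[0]
        cls = self.cls_token.expand(B, -1, -1)     # (B, 1, D)
        x = torch.cat([cls, patch_tokens], dim=1)  # (B, N + 1, D), cls at index 0
        return x + self.pos_embed                  # add the positional embedding
```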


The main properties of this extra token (the class embedding) are that it is not derived from the image content and that its positional encoding is fixed.

Advantages:

  • The token is randomly initialized and updated as the network trains, so it can encode statistical properties of the entire dataset;
  • The token aggregates information from all the other tokens (global feature aggregation), and since it is not derived from the image content, it avoids being biased toward any particular token in the sequence;
  • Using a fixed positional encoding for this token prevents its output from being disturbed by the positional encoding.

In ViT, the authors place the class embedding at the head of the sequence rather than the tail, i.e., at position 0. This way, even if the sequence length n changes, the positional encoding of the class embedding stays fixed. It is therefore more accurate to call the class embedding the 0th token rather than the (n+1)-th token.
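
One practical consequence, from the ViT paper's fine-tuning procedure: when the input resolution (and hence n) changes, the patch position embeddings are 2D-interpolated to the new grid, while slot 0 for the class token is carried over unchanged. A hedged sketch of that bookkeeping (the function name and shapes are assumptions):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize ViT position embeddings to a new patch grid (sketch).

    pos_embed: (1, 1 + old_grid**2, D); slot 0 belongs to the class token
    and is kept as-is, so its position never changes with sequence length.
    """
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    D = patch_pos.shape[2]
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=(new_grid, new_grid),
                              mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_pos, patch_pos], dim=1)
```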

In addition, is it feasible to average the first n tokens as the feature to be classified? That is also a form of global feature aggregation, but it is less expressive than aggregating with the attention mechanism: attention can adaptively adjust the aggregation weights according to the relationship between the query and each key, whereas averaging gives every key the same weight, which limits the model's expressive power.
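
A tiny numeric sketch of the difference (purely illustrative, not ViT code): averaging is attention with the weights pinned at 1/N, whereas attention computes data-dependent weights from query-key similarity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 4, 8
tokens = torch.randn(N, D)   # N token features
query = torch.randn(D)       # e.g., the class token's query vector

uniform = tokens.mean(0)     # averaging = attention with weights fixed at 1/N

# Attention: weights adapt to the query-key similarity.
scores = tokens @ query / D ** 0.5   # (N,) scaled dot-product scores
weights = F.softmax(scores, dim=0)   # data-dependent, sums to 1
attended = weights @ tokens          # (D,) weighted aggregation

print(weights)  # unequal weights, unlike the fixed 1/N of averaging
```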

The class token: a vector fed into the Transformer blocks together with the input tokens, whose final output is used to predict the category. The Transformer thus processes a total of N+1 tokens of dimension D, and in the end only this one token's output is used for prediction. This architecture forces information to propagate between the patch tokens and the class token.

Open questions:

  • The class embedding should be the 0th token rather than the (n+1)-th token.
  • Yet earlier phrasing says that only the output of the "last" token is used to predict the class.
  • Are the two contradictory? How should this be understood? One reading: "last" refers to the class token's output at the final layer, not to the last position in the sequence, so there is no contradiction.

Origin: blog.csdn.net/MengYa_Dream/article/details/126600748