[ViT Detailed Explanation] Vision Transformer Network Structure and Code Interpretation


Paper link: https://arxiv.org/abs/2010.11929
Source address: https://github.com/google-research/vision_transformer

Foreword

This week I re-read ViT, a milestone paper in the vision field, following Mushen's paper-reading series. Every time I read it I gain some new understanding, so I record my notes here for reference.
The most important contribution of ViT is to apply the basic Transformer paradigm to the CV field and to verify the feasibility of this idea through a series of experiments. It established another heavyweight base model for vision after convolutional neural networks and graph neural networks, and opened up a new path toward unified multi-modal models.

In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks.

1 Network structure


The figure above is the ViT network structure diagram shown by the authors in the original paper. Apart from splitting the image into patches as input, the remaining operations are not much different from the standard NLP Transformer pipeline. Next, we follow this figure step by step and discuss how ViT performs image classification.

1.1 image2patch: convert image to image block

The network structure clearly shows that the input image is cut into a series of patches (image blocks), which become the raw input of the network. Suppose the original input image is H x W x C; we need to cut it into blocks to obtain the patch input. If the patch size is P1 x P2, the final number of patches N is: N = (H/P1) x (W/P2). In this process, we need to pay attention to the following points:

  • H and W must be divisible by P1 and P2 respectively, otherwise the network will directly throw an exception.
  • Each patch is flattened: ignoring the batch dimension, the input goes from the three-dimensional C x (h·P1) x (w·P2) to the two-dimensional (h·w) x (P1·P2·C), where h = H/P1 and w = W/P2. Here h·w (= N) plays the role of the token sequence length in NLP, and P1·P2·C is the dimension of each token. This step avoids the high complexity that a pixel-level sequence would impose on self-attention, and adapts the image to the input form the Transformer architecture expects. The corresponding operation in the source code matches this understanding, as shown in the code and shape check below.
# The Rearrange layer uses an einops (Einstein-notation) expression for the dimension conversion; see the einops docs for details
self.image2patch = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
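As a quick sanity check (a minimal sketch with assumed sizes, not taken from the original source), running this Rearrange layer on a dummy 224 x 224 RGB image with 16 x 16 patches yields N = 14 x 14 = 196 tokens of dimension 16 x 16 x 3 = 768:

import torch
from einops.layers.torch import Rearrange

# hypothetical example values: 224x224 RGB image, 16x16 patches
image2patch = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
img = torch.randn(1, 3, 224, 224)   # (batch, channels, H, W)
patches = image2patch(img)
print(patches.shape)                # torch.Size([1, 196, 768]) -> N=196 tokens, each 16*16*3=768-dim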

1.2 patch_dim: dimension mapping

In practice we want to describe each patch with a suitable embedding dimension, which requires mapping the flattened (P1·P2·C)-dimensional patch to a chosen dimension dim. ViT therefore adds a fully connected layer here to project the patch dimension:

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),  # project patch_dim to dim
        )

1.3 position_embedding: Embedding position information

Consistent with the operation in NLP, when we input a token sequence we need to mark the position of each token in the sequence. The usual approach is to compute position information with a specific formula and add it directly to the token embeddings; for example, the position embedding in the original Transformer uses the sinusoidal formula

PE(pos, 2i) = sin(pos / 10000^(2i/d_model)),  PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

The commented position embedding code is shown below (note that the ViT implementation in Section 2.2 instead learns the position embedding as a parameter rather than using this fixed formula):

import math
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        """
        :param d_model: dimension of the positional encoding, usually equal to the word embedding dimension so the two can be added
        :param dropout: dropout probability
        :param max_len: length of the longest sentence in the corpus, i.e. L in the word embedding
        """
        super(PositionalEncoding, self).__init__()
        # define dropout
        self.dropout = nn.Dropout(p=dropout)
        # compute the positional encoding table
        pe = torch.zeros(max_len, d_model)  # empty table: each row is a position, each column an encoding dimension
        position = torch.arange(0, max_len).unsqueeze(1)  # position indices for the formula, size=(max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *    # the 1 / 10000**(2i/d_model) factor in the formula
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # sine values for the even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # cosine values for the odd dimensions
        pe = pe.unsqueeze(0)  # size=(1, max_len, d_model), so it broadcasts over the batch dimension when added
        self.register_buffer('pe', pe)  # the pe values are not trained

    def forward(self, x):
        # final input encoding = word_embedding + positional_embedding
        x = x + self.pe[:, :x.size(1)]  # size = [batch, L, d_model]; the buffer requires no gradient
        return self.dropout(x)  # size = [batch, L, d_model]
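A quick usage sketch of the class above, with assumed sizes:

# hypothetical shapes: batch of 2 sequences, 10 tokens, 512-dim embeddings
pos_enc = PositionalEncoding(d_model=512, dropout=0.1)
tokens = torch.randn(2, 10, 512)
out = pos_enc(tokens)
print(out.shape)  # torch.Size([2, 10, 512]) -- same shape, position information added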

Through position embedding we add position information to the patch sequence, which is the meaning of Patch + Position Embedding in the network structure diagram.

1.4 cls_token: embedded category information

The traditional Transformer adopts the Seq2Seq form, but the Vision Transformer only keeps the encoder part; there is no decoder. To obtain a single, well-defined input for the classification head, the authors prepend a learnable cls token of size (1 x 1 x dim) and concatenate it with the sequence of projected patch embeddings. Because this extra classification token is added, the final input sequence has N + 1 tokens, i.e. (H/P1)·(W/P2) + 1 tokens of dimension dim each.
The specific internal details of the above four operations are shown in the following figure:
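To make the shapes concrete, here is a minimal sketch (with assumed sizes) of prepending the cls token and adding the position embedding, mirroring the forward pass in Section 2.2:

import torch

# hypothetical sizes: batch = 1, N = 196 patch tokens, dim = 768
dim, num_patches = 768, 196
patch_tokens = torch.randn(1, num_patches, dim)       # projected patch embeddings
cls_token = torch.randn(1, 1, dim)                    # learnable class token in the real model
pos_embedding = torch.randn(1, num_patches + 1, dim)  # learnable position embedding in the real model

x = torch.cat((cls_token, patch_tokens), dim=1)       # (1, N + 1, dim): sequence length becomes N + 1
x = x + pos_embedding                                 # position information is simply added
print(x.shape)                                        # torch.Size([1, 197, 768])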

1.5 Transformer Encoder

The main part of the network structure is the encoder. The network achieves a better classification effect by stacking encoder blocks; its internal structure is shown in the figure below.

The input, which already carries the embedded position and class information, is first processed by a Layer Norm and then enters the Multi-Head Attention layer, where the three vectors Q, K, and V are generated; the subsequent operations are the same as in the Transformer. When computing Q·K^T, we can regard the inner product of two vectors as measuring the association between image patches (similar to the word-vector similarity computation in the Transformer). After obtaining the attention weights, they are applied to V, and the result then passes through the MLP layer to produce the output of the Encoder block (multiple Encoder blocks can be stacked here, as shown in the figure above). As in the Transformer, the significance of multiple heads is that they encourage the model to learn comprehensive, multi-level, multi-angle information and thus richer features: for the same picture, different people notice somewhat different parts, and the multiple heads combine exactly these differences during learning.
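To make the above computation concrete, here is a minimal single-head sketch of scaled dot-product attention (illustrative sizes; the full multi-head module appears in Section 2.4):

import torch

# hypothetical sizes: 197 tokens (196 patches + cls), head dimension 64
n, d = 197, 64
q, k, v = torch.randn(1, n, d), torch.randn(1, n, d), torch.randn(1, n, d)

scores = torch.matmul(q, k.transpose(-1, -2)) * d ** -0.5  # patch-to-patch association, scaled by sqrt(d)
attn = scores.softmax(dim=-1)                              # attention weights over the sequence
out = torch.matmul(attn, v)                                # weighted sum of the value vectors
print(out.shape)                                           # torch.Size([1, 197, 64])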

1.6 MLP

After the Transformer Encoder, we arrive at the final classification part. Before running the encoder we concatenated an extra learnable classification token; now we take the encoder output at that token's position and feed it into the MLP Head, i.e. Layer Norm --> fully connected --> GELU --> fully connected, to obtain the final output. (Note that the implementation in Section 2.2 below uses a simplified head: LayerNorm followed by a single fully connected layer.)
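Here is a sketch of the head as described above, with illustrative sizes (again, the code in Section 2.2 keeps only LayerNorm and a single linear layer):

import torch
from torch import nn

dim, hidden_dim, num_classes = 768, 3072, 1000   # hypothetical sizes

mlp_head = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, hidden_dim),
    nn.GELU(),
    nn.Linear(hidden_dim, num_classes),
)

cls_output = torch.randn(1, dim)   # encoder output at the cls token position
logits = mlp_head(cls_output)
print(logits.shape)                # torch.Size([1, 1000])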

2 Code analysis

2.1 Library dependencies

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
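The ViT class below also uses a small pair helper from the vit-pytorch source, which normalizes image_size and patch_size into (height, width) tuples; it is roughly:

def pair(t):
    # allow image_size / patch_size to be given either as an int or as a (height, width) tuple
    return t if isinstance(t, tuple) else (t, t)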

2.2 Main structure

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
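A usage sketch with illustrative hyperparameters (assuming the Transformer, Attention, and helper modules from the rest of this section are defined):

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 3072,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)   # dummy RGB image
preds = model(img)
print(preds.shape)                  # torch.Size([1, 1000]) -- one score per class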

2.3 Transformer

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
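The PreNorm and FeedForward modules referenced here follow the usual vit-pytorch pattern: a pre-norm wrapper and a two-layer GELU MLP. A sketch:

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # LayerNorm applied before the wrapped module (pre-norm)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)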

2.4 Attention

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
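A quick shape check for the Attention module, with illustrative sizes:

attn = Attention(dim = 768, heads = 12, dim_head = 64, dropout = 0.)
x = torch.randn(1, 197, 768)   # 196 patch tokens + 1 cls token, each 768-dim
print(attn(x).shape)           # torch.Size([1, 197, 768]) -- same shape as the input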

Origin blog.csdn.net/weixin_43427721/article/details/126608144