Vision Transformer (ViT): principle analysis and feature visualization

Table of contents

Introduction to ViT

ViT model structure diagram

ViT input processing

Image tiling

Adding the class token and position embedding

Feature extraction

ViT code


Introduction to ViT

Vision Transformer (ViT) is a deep learning model based on the Transformer architecture for image recognition and other computer vision tasks. Unlike traditional convolutional neural networks (CNNs), ViT treats the image directly as a serialized input and uses the self-attention mechanism to model the relationships between image regions.

ViT works by dividing the image into a series of patches and converting each patch into a vector representation, forming an input sequence. These vectors are then processed by a multi-layer Transformer encoder consisting of self-attention and feed-forward layers, which captures contextual dependencies between different locations in the image. Finally, specific vision tasks can be accomplished by classifying or regressing on the Transformer encoder output.

ViT code reference: Bubbliiing's CSDN blog post "Detailed explanation of reproducing the PyTorch version of the Vision Transformer (ViT) model".

Why can't the Transformer be applied to images directly? The Transformer was designed for sequence tasks (such as NLP), while an image is a two- or three-dimensional grid whose pixels have strong structural relationships. If self-attention were applied directly at the pixel level, every pixel would need to attend to every other pixel, and the amount of computation would be enormous. ViT was born to address this.
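
A rough back-of-the-envelope sketch (assuming the 224×224 input and 16×16 patches used later in this post) makes the difference concrete:

# Rough sketch: self-attention cost over raw pixels vs. over 16x16 patches for a 224x224 image.
pixels = 224 * 224                   # 50176 tokens if every pixel were its own token
patches = (224 // 16) ** 2           # 196 tokens when using 16x16 patches
print(pixels ** 2)                   # 2517630976 (~2.5e9) pairwise attention entries
print(patches ** 2)                  # 38416 pairwise attention entries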


ViT model structure diagram

The model structure of ViT is shown in the figure below. ViT applies image patches to the Transformer. A CNN is based on the sliding-window idea: it slides convolution kernels over the image to obtain feature maps. To make the image mimic an NLP input sequence, we first divide the image into patches, then flatten these patches and feed them into the network (this becomes an image sequence), then extract features with the Transformer, and finally classify those features with an MLP. [In fact, this can be understood as replacing the backbone of a conventional CNN classifier with a Transformer.]

Figure 1: Model overview. We split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable “classification token” to the sequence. The illustration of the Transformer encoder was inspired by Vaswani et al. (2017).

ViT input processing

Image tiling

Image patching corresponds to the patches in the ViT diagram above, and position embedding attaches the positional information of each image patch. So how do we split the image into patches? The simplest way is with a convolution: by setting the convolution kernel size and stride we control the resolution of each patch and how many patches the image is divided into.

How is it implemented in the code? You can see the code below.

class PatchEmbed(nn.Module):
    def __init__(self, input_shape=[224, 224], patch_size=16, in_channels=3, num_features=768, norm_layer=None, flatten=True):
        super().__init__()
        self.num_patch = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)  # 14*14 patches = 196
        self.flatten = flatten

        # A conv with kernel_size = stride = patch_size cuts the image into non-overlapping patches
        # and projects each patch to a num_features-dimensional embedding in one step.
        self.proj = nn.Conv2d(in_channels, num_features, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)  # convolution first: [1, 3, 224, 224] -> [1, 768, 14, 14]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # flatten(2): [1, 768, 196], then transpose(1, 2): [1, 196, 768]
        x = self.norm(x)
        return x

In the code above, num_patch is the number of patches the image is divided into. proj performs the patch-splitting convolution on the input image and maps each patch to a feature vector. Assume the input has shape [1, 3, 224, 224]; after the convolution we obtain [1, 768, 14, 14], i.e. a 14×14 grid of patches, each represented by a 768-dimensional feature vector.

Each image patch can be extracted and visualized individually:

The input image
One of the image patches

The flattening operation is then performed, giving [1, 768, 196], which is transposed to [1, 196, 768]; finally the sequence passes through a LayerNorm layer (or Identity if norm_layer is None) to obtain the final output.

Visualization of the input sequence:
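
As a quick sanity check, a minimal shape test of the PatchEmbed module above (a sketch, assuming PyTorch is installed) looks like this:

import torch

# Quick shape check: a 224x224 image becomes 196 tokens of dimension 768.
x = torch.randn(1, 3, 224, 224)
patch_embed = PatchEmbed(input_shape=[224, 224], patch_size=16, in_channels=3, num_features=768)
print(patch_embed(x).shape)  # torch.Size([1, 196, 768])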

Adding the class token and position embedding

Through the above operations we obtain the flattened feature sequence (shape [1, 196, 768]). A class token is then prepended to this sequence and sent through the network together with the patch features for feature extraction. The class token is the 0* in the figure, so the original sequence of length 196 becomes a sequence of length 197.

A position embedding is then added, which attaches positional information to the whole feature sequence: a learnable [197, 768] matrix is created and added element-wise to the feature sequence. At this point the preprocessing of the network input (patch + position embedding) is complete.

# Definition of the class token
self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))

# Definition of the position embedding
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))

Code:

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        """
        With an input size of 224 and a 16x16 conv kernel, the image is split into 14x14 patches.
        :param input_shape: network input size
        :param patch_size: patch size
        :param in_chans: number of input channels
        :param num_classes: number of classes
        :param num_features: feature dimension
        :param num_heads: number of heads in multi-head attention
        :param mlp_ratio: MLP ratio
        :param qkv_bias: bias for qkv
        :param drop_rate: dropout rate
        :param norm_layer: layer normalization
        :param act_layer: activation function
        """
        super().__init__()
        #-----------------------------------------------#
        #   224, 224, 3 -> 196, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_channels=in_chans, num_features=num_features)
        num_patches         = (224 // patch_size) * (224 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))  # shape [1, 1, 768]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))  # shape [1, 197, 768]
        self.pos_drop  = nn.Dropout(p=drop_rate)  # dropout applied after adding the position embedding

    def forward_features(self, x):
        x = self.patch_embed(x)  # patchify first: [1, 196, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # [1, 1, 768]
        x = torch.cat((cls_token, x), dim=1)  # [1, 197, 768]
        cls_token_pe = self.pos_embed[:, 0:1, :]  # position embedding of the class token
        img_token_pe = self.pos_embed[:, 1:, :]  # the remaining 196 entries are the position embeddings of the image patches

        # Reshape the patch position embeddings to a 14x14 grid and resize them so that inputs
        # of a different resolution can reuse the 224-based position embedding.
        img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)  # [1, 768, 14, 14]
        img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)  # [1, 196, 768]
        pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)  # [1, 197, 768]

        x = self.pos_drop(x + pos_embed)  # add the position information to the token sequence
        # ... (continued in the full listing in the "ViT code" section below)

Feature extraction

As with a CNN, a backbone is needed for feature extraction. In ViT, the Transformer encoder serves as this backbone.

Our input is a [197, 768] sequence, where the 197 tokens consist of the learnable class token plus the 196 image tokens, with the learnable pos_embed added on top. This sequence is fed into the encoder for feature extraction, and the key component of feature extraction in the Transformer is multi-head attention.

In the figure above, the input first passes through a Norm layer and is then projected into three parts, q, k and v, which are fed into the multi-head attention mechanism at the same time; this is the self-attention mechanism. The result is added to the input via a residual connection, then passed through another Norm layer and an MLP to produce the output.

q is our query sequence. Multiplying q by k gives the correlation, or importance, between each query vector in q and each key vector in k. This result is then used to weight the original value vectors v, giving the contribution of each token [actually somewhat similar to a channel attention mechanism].
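
In formula form, this is the standard scaled dot-product attention (d_k is the per-head key dimension, 64 here), which is exactly what the Attention class shown below computes:

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V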

Features are extracted by stacking many such self-attention blocks. Compared with a CNN, the basic building block of the Transformer is self-attention, while the basic building block of a CNN is the convolution kernel.

Self-attention mechanism code:

qkv in the code:

# Geometric meaning: q, k and v are split across num_heads heads (each head gets its own q, k and v), and each head holds a 197x64 feature sequence.

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5  # 1 / sqrt(head_dim)

        self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)  # one linear layer produces q, k and v
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C     = x.shape
        # [B, N, 3*C] -> [3, B, num_heads, N, head_dim]
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge the heads back to [B, N, C]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
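
A quick shape check of the Attention module above (a sketch, assuming the class is defined as shown):

import torch

# Shape check: 197 tokens of dimension 768, split internally into 12 heads of 64 dimensions each.
attn_layer = Attention(dim=768, num_heads=12, qkv_bias=True)
tokens = torch.randn(1, 197, 768)
print(attn_layer(tokens).shape)  # torch.Size([1, 197, 768])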

In the same way, we can visualize q, k and v to see what they contain. q, k and v all have the same shape, [batch_size, num_heads, 197, 768 // num_heads]. Below we visualize the first head of q (there are 12 heads here, and each head holds 64 feature dimensions).

The q feature vector on the first head

Then let's look at the features of k on the first head.

The k feature vector on the first head

The attention weights are then obtained by multiplying the q and k matrices.

ps: Why does the code compute q @ k.transpose(-2, -1) rather than a plain kᵀ? Because q and k are 4-D tensors of shape [B, num_heads, 197, head_dim]: the batched matrix multiplication only operates on the last two dimensions, and the dot product between q and k must be taken over those last two dimensions for every head. So we only need to swap the last two dimensions.
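
A small sketch with random tensors (shapes only) shows that the batched matmul leaves the batch and head dimensions untouched:

import torch

# q and k are [B, num_heads, N, head_dim]; torch.matmul broadcasts over the leading dims.
q = torch.randn(1, 12, 197, 64)
k = torch.randn(1, 12, 197, 64)
scores = q @ k.transpose(-2, -1)
print(scores.shape)  # torch.Size([1, 12, 197, 197]) -- one 197x197 attention map per head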

The attention map obtained from the matrix multiplication of q and k is shown below. Again we only visualize the first head [there are 12 heads in total, and each head's attention map is different]:

Softmax is then applied along the last dimension to obtain the attention scores on all heads.

The attention score of the first head is:

tensor([[9.9350e-01, 2.5650e-05, 2.6444e-05,  ..., 3.7445e-05, 3.3614e-05,
         2.7365e-05],
        [3.7948e-01, 2.3743e-01, 8.7877e-02,  ..., 2.2976e-05, 1.2177e-04,
         6.6991e-04],
        [3.7756e-01, 1.2583e-01, 1.4249e-01,  ..., 1.0860e-05, 3.4743e-05,
         1.1384e-04],
        ...,
        [4.1151e-01, 3.6945e-05, 9.8513e-06,  ..., 1.5886e-01, 1.1042e-01,
         4.4855e-02],
        [4.0967e-01, 1.7754e-04, 2.8480e-05,  ..., 1.0884e-01, 1.4333e-01,
         1.2111e-01],
        [4.1888e-01, 6.8779e-04, 6.7465e-05,  ..., 3.5659e-02, 9.4098e-02,
         2.2174e-01]], device='cuda:0')

The obtained attention scores are then multiplied by v, so each output token is a weighted combination of the value vectors, i.e. the contribution of every token.
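
Sketched with random tensors, this last step is a weighted sum over v followed by merging the 12 heads back into 768 channels:

import torch

# attn holds the per-head scores, v the per-head values; heads are merged back after the weighted sum.
attn = torch.randn(1, 12, 197, 197).softmax(dim=-1)
v = torch.randn(1, 12, 197, 64)
out = (attn @ v).transpose(1, 2).reshape(1, 197, 768)
print(out.shape)  # torch.Size([1, 197, 768])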

Adding the MLP layer then gives us the complete Transformer Block.

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = (drop, drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        '''
        :param x: input sequence
        x --> layer_norm --> multi-head attention --> + --> x --> layer_norm --> mlp --> + --> x
        |_____________________________________________|    |_____________________________|
        '''
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
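
A Block keeps the [197, 768] token shape unchanged, so blocks can be stacked freely. A minimal sketch (assuming the classes above and their GELU dependency are defined):

import torch

# One encoder block: pre-norm multi-head attention + MLP, each wrapped in a residual connection.
block = Block(dim=768, num_heads=12, mlp_ratio=4., qkv_bias=True)
x = torch.randn(1, 197, 768)
print(block(x).shape)  # torch.Size([1, 197, 768])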

ViT code

GitHub - YINYIPENG-EN/vit_classification_pytorch: Use ViT to implement image classification. https://github.com/YINYIPENG-EN/vit_classification_pytorch.git

# Imports assumed by the full listing below (PatchEmbed, Attention, Mlp, Block and a DropPath
# implementation are the ones defined earlier in this post / in the linked repository).
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.nn import GELU  # the linked repo may define its own GELU; torch.nn.GELU is used here as a stand-in


class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        """
        With an input size of 224 and a 16x16 conv kernel, the image is split into 14x14 patches.
        :param input_shape: network input size
        :param patch_size: patch size
        :param in_chans: number of input channels
        :param num_classes: number of classes
        :param num_features: feature dimension
        :param num_heads: number of heads in multi-head attention
        :param mlp_ratio: MLP ratio
        :param qkv_bias: bias for qkv
        :param drop_rate: dropout rate
        :param norm_layer: layer normalization
        :param act_layer: activation function
        """
        super().__init__()
        #-----------------------------------------------#
        #   224, 224, 3 -> 196, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_channels=in_chans, num_features=num_features)
        num_patches         = (224 // patch_size) * (224 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]

        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))  # shape [1, 1, 768]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))  # shape [1, 197, 768]
        self.pos_drop  = nn.Dropout(p=drop_rate)  # dropout applied after adding the position embedding

        # -----------------------------------------------#
        #   197, 768 -> 197, 768, repeated `depth` = 12 times
        # -----------------------------------------------#
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=num_features,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer
                ) for i in range(depth)
            ]
        )
        self.norm = norm_layer(num_features)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()


    def forward_features(self, x):
        x = self.patch_embed(x)  # patchify first: [1, 196, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # [1, 1, 768]
        x = torch.cat((cls_token, x), dim=1)  # [1, 197, 768]
        cls_token_pe = self.pos_embed[:, 0:1, :]  # position embedding of the class token
        img_token_pe = self.pos_embed[:, 1:, :]  # the remaining 196 entries are the position embeddings of the image patches

        img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)  # [1, 768, 14, 14]
        img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)  # [1, 196, 768]
        pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)  # [1, 197, 768] final position embedding

        x = self.pos_drop(x + pos_embed)  # add the position information to the token sequence

        x = self.blocks(x)  # feature extraction
        x = self.norm(x)
        return x[:, 0]  # return only the class token, which is used for classification

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def freeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = False
            except AttributeError:  # cls_token and pos_embed are nn.Parameter, not modules
                module.requires_grad = False

    def Unfreeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = True
            except AttributeError:  # cls_token and pos_embed are nn.Parameter, not modules
                module.requires_grad = True
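
Finally, a minimal usage sketch of the full model (assuming all of the classes above, plus a DropPath implementation such as the one in the linked repo, are available):

import torch

# Forward a dummy 224x224 image through the full ViT and get the class logits.
model = VisionTransformer(input_shape=[224, 224], patch_size=16, num_classes=1000)
img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(logits.shape)  # torch.Size([1, 1000])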


Origin blog.csdn.net/z240626191s/article/details/132504292