ViT: A super-detailed explanation of the Vision Transformer, with code

Original paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

1. ViT model architecture

In simple terms, the model consists of three modules:

(1) Linear Projection of Flattened Patches (the embedding layer)

(2) Transformer Encoder

(3) MLP Head: the final classification layer

Specific steps:

1.1 Split the image into patches

1.2 Convert each patch into an embedding

Since a patch is a square block of pixels, it cannot be fed directly into the Transformer (TRM); it must first be converted into an embedding of fixed dimension, which then serves as the Transformer's input. Step 1: flatten the patch from 2D to 1D (e.g. a 16x16 patch becomes a vector of length 256, or 768 once the three color channels are included). Step 2: map the flattened vector to an embedding dimension of our own choosing with a linear layer.
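A minimal sketch of these two steps (the sizes are illustrative: a 16x16 RGB patch with flattened length 3*16*16 = 768, projected to a 1024-dimensional embedding to match dim in the code of section 2):

import torch
from torch import nn

patch = torch.randn(3, 16, 16)        # one 16x16 RGB patch
flat = patch.flatten()                # step 1: 2D -> 1D, length 3*16*16 = 768
proj = nn.Linear(flat.numel(), 1024)  # step 2: linear projection to the chosen embedding size
embedding = proj(flat)                # shape: (1024,)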

Note: the paper experiments with two ways of implementing this projection. The first is a plain linear transformation (Linear Projection) applied to the flattened patch. The second exploits the fact that a patch is 16x16: a convolution with a 16x16 kernel and a stride of 16 does the same job, and setting the number of output channels to 768 maps each patch directly to the 768-dimensional input expected by the Transformer Encoder.
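A minimal sketch of the convolutional variant, assuming a 224x224 RGB input and 768 output channels:

import torch
from torch import nn

conv_proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
img = torch.randn(1, 3, 224, 224)
feat = conv_proj(img)                     # (1, 768, 14, 14): one 768-d vector per patch
tokens = feat.flatten(2).transpose(1, 2)  # (1, 196, 768): a sequence of patch embeddings

The 196 tokens here are exactly the 14x14 grid of 16x16 patches, so the convolution and the flatten-plus-linear route produce sequences of the same shape.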

1.3 Add the position embedding and the CLS token embedding

First, generate the token embedding for the CLS symbol (the * in the figure). Then generate the position encodings for every position in the sequence (the 1, 2, 3, ... in the figure). The two (pink and purple in the figure) are added together to obtain the input embedding.

Why add a CLS symbol?

Ablations later in the paper show that the CLS token matters little; its main purpose is to minimize the changes made to the original Transformer. BERT uses a CLS token because it has two pre-training tasks: NSP, a binary classification task that predicts whether one sentence follows another, and MLM, which predicts masked words. If both tasks computed their losses by pooling over the tokens, they would interfere on some tokens; using CLS keeps the two tasks relatively independent. ViT, however, has no MLM-style task, only a single multi-class classification task, so the CLS token is not strictly necessary.

Position encoding

To preserve the spatial relationships between the input image patches, a position encoding vector (E_pos in the paper's formula) is added to the patch embeddings. ViT does not use a 2D-aware position embedding; it simply uses a 1D learnable position embedding, because the authors found that 2D position embeddings performed no better than 1D in practice.
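A minimal sketch of this whole step, with illustrative sizes; cls_token and pos_embedding are learnable parameters, just as in the code in section 2:

import torch
from torch import nn

dim, num_patches, b = 1024, 196, 1
cls_token = nn.Parameter(torch.randn(1, 1, dim))                    # the CLS token embedding
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 1D learnable position embeddings (+1 for CLS)

patches = torch.randn(b, num_patches, dim)                    # patch embeddings from section 1.2
x = torch.cat((cls_token.expand(b, -1, -1), patches), dim=1)  # prepend the CLS token -> (b, 197, dim)
x = x + pos_embedding                                         # add the position embeddings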

1.4 Feed the input into the Transformer (TRM) Encoder

The input first passes through a normalization (LayerNorm) layer and then the self-attention layer; the attention output is added to the input through a residual connection. The result goes through another LayerNorm and the feed-forward network, again followed by a residual connection. This block is repeated once for every Encoder layer, and in the end an output vector is produced for each token.
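A minimal sketch of one such pre-norm encoder block, using PyTorch's built-in nn.MultiheadAttention instead of the einops-based implementation shown in section 2 (sizes are illustrative):

import torch
from torch import nn

dim, heads = 1024, 16
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.MultiheadAttention(dim, heads, batch_first=True)
ffn = nn.Sequential(nn.Linear(dim, 2048), nn.GELU(), nn.Linear(2048, dim))

x = torch.randn(1, 197, dim)                  # 196 patch tokens + 1 CLS token
h = norm1(x)                                  # LayerNorm first (pre-norm)
x = x + attn(h, h, h, need_weights=False)[0]  # self-attention sub-layer + residual
x = x + ffn(norm2(x))                         # feed-forward sub-layer + residual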

1.5 Use the CLS output for the classification task

Take the output at the first position (the CLS token) and feed it to the classification head for the multi-class classification task.
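A minimal sketch of this last step, with illustrative sizes (the real version is the mlp_head in the code below):

import torch
from torch import nn

dim, num_classes = 1024, 1000
mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

encoded = torch.randn(1, 197, dim)  # encoder output: (batch, tokens, dim)
cls_out = encoded[:, 0]             # output of the first token, the CLS token: (1, 1024)
logits = mlp_head(cls_out)          # class scores: (1, 1000)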

2. Code

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
# Multi-head self-attention implementation
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)  # map dim to inner_dim * 3 (q, k and v in a single projection)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)  # split the projected tensor into 3 chunks; x: 1*197*1024, qkv is a tuple of length 3, each element of shape 1*197*1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)  # multiply the attention weights by the corresponding v matrix
        out = rearrange(out, 'b h n d -> b n (h d)')  # reshape: merge the heads back together
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        # stack depth encoder blocks on top of each other
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),  # multi-head self-attention
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))  # feed-forward network
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# Overall ViT architecture
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)  # e.g. 224 * 224
        patch_height, patch_width = pair(patch_size)  # e.g. 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)  # number of patches the image is split into
        patch_dim = channels * patch_height * patch_width  # flattened patch dimension: height * width * channels
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # flatten each patch and project it to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # learnable position embeddings for all positions (patches + CLS)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))   # learnable CLS token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # with the input embedding ready, feed it to the Transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # img: batch 1, 3 channels, 224x224; output x shape: 1*196*1024
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)  # replicate the CLS token b times, one per sample in the batch
        x = torch.cat((cls_tokens, x), dim=1)  # concatenate the CLS token embedding with the patch embeddings
        x += self.pos_embedding[:, :(n + 1)]  # add the position embeddings
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]  # mean-pool all tokens, or take the CLS token

        x = self.to_latent(x)
        return self.mlp_head(x)



v = ViT(
    image_size = 224,  # input image size
    patch_size = 16,  # size of each patch
    num_classes = 1000,  # number of classes the CLS output is mapped to
    dim = 1024,  # embedding dimension
    depth = 6,  # number of encoder layers
    heads = 16,  # number of attention heads
    mlp_dim = 2048,  # hidden dimension of the feed-forward network
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

3. To summarize

I wanted to write a proper summary, but it is hard to put one together: the model itself is simple and there is nothing exotic about it, yet I do not fully understand every detail. So I am publishing this first and will revise it when I get the chance.
 
