Paper reading notes: Vision Transformer (ViT)

1. Vision Transformer

Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).

This is the paper that laid the groundwork for Transformers to challenge convolution in the vision domain. After the Transformer's success in NLP, it also achieved excellent results on vision tasks. The authors abandon convolution entirely: the image is split into several patches, each patch is encoded, and the resulting sequence is fed into a Transformer just like a text sequence. On medium-sized datasets the results fall short of convolutional networks, but on large-scale datasets ViT is able to surpass them.

[Figure: ViT architecture — the image is split into patches, linearly projected, combined with position embeddings and a class token, and fed into a Transformer encoder]
Assume an image of size H × W × C and a patch size of P × P. The image can then be divided into N patches, where N = (H × W) / P², and reshaped into an N × (P² · C) matrix; that is, each patch is initially encoded as a vector of length P² · C. After a linear projection and the addition of position embeddings, the sequence can be trained just like text. As shown in the figure, nine patches are fed into the network, but it is not obvious which token's output should be used for image classification, so an extra learnable cls_token is prepended to the sequence. Its dimension matches that of a patch embedding, and it can be viewed as an artificially added patch whose final state is used for classification.
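To make these shapes concrete, here is a minimal sketch (using an assumed 224 × 224 RGB image and 16 × 16 patches, not the toy nine-patch example in the figure) of the patch flattening with einops:

import torch
from einops import rearrange

# Assumed example: H = W = 224, C = 3, P = 16
img = torch.randn(1, 3, 224, 224)  # [batch, C, H, W]
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
print(patches.shape)  # torch.Size([1, 196, 768])
# N = (224 * 224) / 16**2 = 196 patches, each a vector of length P^2 * C = 768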

This is not the first time attention has been used in image processing. The SE (squeeze-and-excitation) block is also an attention mechanism, but it acts only on the channel dimension, whereas ViT's attention is global: every patch can attend to every other patch. In fact, when the patch size is 1 × 1, the effect is quite similar to an SE block.
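For comparison, below is a minimal sketch of an SE block (my own illustration, not part of the paper or the code in the next section): it squeezes the spatial dimensions and reweights channels, while ViT's self-attention, implemented further down, mixes information across all patch positions.

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Minimal squeeze-and-excitation block: attention over channels only."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):              # x: [batch, C, H, W]
        w = x.mean(dim=(2, 3))         # squeeze: global average pool -> [batch, C]
        w = self.fc(w)                 # excitation: per-channel weights in (0, 1)
        return x * w[:, :, None, None] # reweight channels; spatial layout untouched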

[Figure: experimental parameter settings of the ViT model variants]
The experimental parameter settings are shown in the figure; the parameter counts of these models are clearly large.

2. Code

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torchvision
from torch.utils import data
import matplotlib.pyplot as plt
import copy
import math

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class PreNorm(nn.Module):

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(normalized_shape=dim)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        x = self.fn(x)
        return x

class FeedForward(nn.Module):

    def __init__(self, dim, hidden_dim, dropout=0.):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.net(x)
        return x

class Attention(nn.Module):

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        project_out = not(heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** (-0.5)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: [batch_size, num_queries, dim]
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)  # q, k, v: [batch_size, num_heads, num_queries, dim_head]
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
                ])
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):

    def __init__(self,image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 1, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super(ViT, self).__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )  # h * w equals the number of patches

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img [batch_size, c, H, W]
        x = self.to_patch_embedding(img)  # [batch_size, num_patch, dim]
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)  # [batch_size, 1, dim]
        x = torch.cat((cls_tokens, x), dim=1)  # [batch_size, 1 + num_patch, dim]
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)

net = ViT(image_size=(224, 224), patch_size=(32, 32), num_classes=10, dim=768, depth=12, heads=12, mlp_dim=3072)
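
A quick smoke test of the model above (note that channels defaults to 1 in this implementation, so the dummy input is single-channel):

x = torch.randn(2, 1, 224, 224)   # dummy batch: two single-channel 224x224 images
logits = net(x)
print(logits.shape)               # torch.Size([2, 10]) — one score per class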

Source: blog.csdn.net/loki2018/article/details/125002995