Vision Transformer (ViT)


Paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

1. Overall ViT Architecture

(Figure: overall ViT architecture)

A brief walkthrough of the architecture

Start with ViT's input.
An image is first split into small patches; for example, ViT-L/16 means each patch is 16×16 pixels. Each patch is fed into the embedding layer (Linear Projection of Flattened Patches), which produces a corresponding vector called a token. In the figure, the image is divided into 9 patches, so 9 token embeddings come out of the embedding layer.
Next, a new token * (the class token) is prepended to this sequence of tokens; it has the same dimension as the patch tokens.
Once the class token and position information have been added to the original tokens, the sequence is fed into the Transformer encoder.

(Figure: the input pipeline)
The structure of the Transformer encoder is as follows:
(Figure: Transformer encoder)
ViT's Transformer encoder stacks the Encoder Block L times. The output corresponding to the class token is then extracted and fed into the MLP Head shown below for classification, producing the final prediction.
(Figure: MLP Head output)

2. Breaking Down ViT

Based on the description above, ViT can be divided into three main parts:

  • The Embedding layer
  • The Encoder
  • The MLP Head, used for classification

The processing pipeline is as follows (a shape-level sketch appears right after the list):

  1. Split the image into patches
  2. Convert each patch into an embedding
  3. Add the position embeddings to the token embeddings
  4. Feed the sequence into the ViT model
  5. Use the CLS output for classification
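
To make the pipeline concrete, here is a shape-level sketch of the five steps. It is purely illustrative and assumes a 224×224 RGB input with 16×16 patches, matching the complete example at the end of this post; dim and num_classes stand for whatever the model is configured with.

import torch

img = torch.randn(1, 3, 224, 224)           # 1. one RGB image, split into 16x16 patches
num_patches = (224 // 16) * (224 // 16)     #    -> 14 * 14 = 196 patches
patch_dim = 3 * 16 * 16                     #    each flattened patch has 768 values
# 2. flatten and project each patch                        -> token embeddings (1, 196, dim)
# 3. prepend the class token, add position embeddings      -> (1, 197, dim)
# 4. pass through the stacked encoder blocks (shape kept)  -> (1, 197, dim)
# 5. classify from the class-token output                  -> logits (1, num_classes)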

The Embedding Layer

A standard Transformer module expects a sequence of token embedding vectors as input; the conversion process is annotated as steps 1, 2, and 3 in the figure below.

(Figure: the embedding operations)
The embedding stage consists of three operations:

  1. Generate the token for the class symbol (marked * in the figure)
  2. Generate position encodings for the entire sequence (shown in light purple in the figure)
  3. Add the position encodings to the token embeddings

In the figure, the original image is first transformed into multiple patches, each of size 3×16×16. Each patch is then flattened into a 768-dimensional vector and turned into a token embedding; converting a patch into an embedding takes two operations:

  • Flatten the patch
  • Map the flattened patch to the dimension required by the encoder

A cls token is prepended to this sequence of embeddings; position encodings are then generated and added to the token embeddings, giving the final input embeddings.

About the position encoding:
In a Transformer, the encoder processes all tokens in parallel rather than one after another, so position encodings are needed to tell the model where each piece of information sits in the sequence. In ViT, they encode the ordering of the image patches.
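
Putting the whole embedding stage together, here is a minimal sketch using einops and illustrative dimensions; the real implementation is the to_patch_embedding, cls_token and pos_embedding code in the complete listing below, where these tensors are learnable parameters.

import torch
from torch import nn
from einops import rearrange

dim, patch = 768, 16                                  # assumed embedding dim and patch size
img = torch.randn(1, 3, 224, 224)                     # one RGB image

# flatten each 3x16x16 patch into a 768-dim vector: (1, 196, 768)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch, p2 = patch)
# map the flattened patches to the dimension required by the encoder
tokens = nn.Linear(3 * patch * patch, dim)(patches)   # (1, 196, 768)

# 1. prepend the class token (learnable in the real model)
cls_token = torch.randn(1, 1, dim)
tokens = torch.cat((cls_token, tokens), dim = 1)      # (1, 197, 768)

# 2. position encodings for all 197 positions (learnable in the real model)
pos_embedding = torch.randn(1, 197, dim)
# 3. token embedding + position encoding = final input embedding
x = tokens + pos_embedding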

Encoder

ViT's encoder module is similar to the one in the original Transformer.
Combining the encoder described in the paper with a concrete implementation gives the following Encoder Block.

(Figure: detailed encoder structure)
Compare this with the encoder input of the original Transformer:

(Figure: original Transformer encoder input)
The LayerNorm (LN) operation is moved ahead of the sub-layers. In addition, because the image is cut into patches of identical size, no padding operation is needed.
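
The change in LayerNorm placement is easiest to see side by side. The sketch below is schematic: attn and mlp stand for the attention and feed-forward sub-layers, norm1 and norm2 for their LayerNorm modules; these names are placeholders, not the actual implementation.

# Original Transformer encoder block (post-LN): normalize after the residual addition
def post_ln_block(x, attn, mlp, norm1, norm2):
    x = norm1(x + attn(x))
    x = norm2(x + mlp(x))
    return x

# ViT encoder block (pre-LN): normalize before each sub-layer, then add the residual
def pre_ln_block(x, attn, mlp, norm1, norm2):
    x = x + attn(norm1(x))
    x = x + mlp(norm2(x))
    return x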

MLP Head

The MLP Head is much easier to understand together with its code implementation:

class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

As the code shows, this MLP consists only of fully connected layers, a GELU activation, and Dropout. Strictly speaking, this Mlp class is the feed-forward block used inside each encoder layer; the head that actually classifies the encoder's class-token output is even simpler, a LayerNorm followed by a single Linear layer (see mlp_head in the complete code below).
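
For reference, the classification head used in the simplified implementation below boils down to the following (a restatement of mlp_head from the complete code, with illustrative sizes):

import torch
from torch import nn

dim, num_classes = 1024, 1000        # sizes used in the example at the end of this post
mlp_head = nn.Sequential(
    nn.LayerNorm(dim),
    nn.Linear(dim, num_classes)
)

cls_output = torch.randn(1, dim)     # class-token output taken from the encoder
logits = mlp_head(cls_output)        # (1, 1000)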

3. A Minimal ViT Implementation

The key modules of the implementation are explained below.

Attention

The attention module is similar to the Transformer's and implements the multi-head attention mechanism. In the forward function, to_qkv and chunk generate the combined Q, K, and V matrices in a single step, which are then split into the per-head q, k, and v. This differs from the original Transformer, which produces Q, K, and V with three separate Linear layers. The fused projection is possible because ViT only uses self-attention, where the queries, keys, and values all come from the same input and there is no decoder cross-attention, so the three projections can be merged into one larger Linear layer.
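
The two styles are easy to compare in isolation. This is a small sketch with illustrative dimensions: to_qkv mirrors the ViT code below, while the three separate projections mirror the original Transformer style.

import torch
from torch import nn

dim, inner_dim = 1024, 512                     # illustrative sizes
x = torch.randn(1, 197, dim)

# ViT style: one fused projection, then split into Q, K, V with chunk
to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
q, k, v = to_qkv(x).chunk(3, dim = -1)         # each of shape (1, 197, 512)

# original Transformer style: three separate projections
W_Q, W_K, W_V = (nn.Linear(dim, inner_dim, bias = False) for _ in range(3))
q2, k2, v2 = W_Q(x), W_K(x), W_V(x)            # each of shape (1, 197, 512)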

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # split the projected tensor into Q, K, V, e.g. x: (1, 197, 1024)
        # to_qkv expands the last dimension to 3x its size; chunk then returns
        # a tuple of length 3, each element of shape (1, 197, 1024)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # reshape q, k, v for multi-head attention
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # scaled dot-product attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        # weight the V matrix by the attention weights
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

Compare this with the multi-head attention of the original Transformer (only the constructor is shown; d_model, d_k, d_v, and n_heads are defined elsewhere in that implementation):

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # The incoming Q, K, and V are identical; separate Linear projections W_Q, W_K, W_V map the input to queries, keys, and values
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

Transformer

Following the architecture in the paper, L encoder blocks are stacked.
(Figure: Transformer encoder)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        # stack depth encoder blocks
        for _ in range(depth):
            # each block follows the paper: pre-norm attention followed by pre-norm feed-forward
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

ViT

The ViT class splits the image into patches, flattens them and maps them to input embeddings, adds the class token and position encodings, and assembles all the modules. The code is given in the complete listing below.

Complete Code

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # split the projected tensor into Q, K, V, e.g. x: (1, 197, 1024)
        # to_qkv expands the last dimension to 3x its size; chunk then returns
        # a tuple of length 3, each element of shape (1, 197, 1024)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # reshape q, k, v for multi-head attention
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # scaled dot-product attention scores
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        # weight the V matrix by the attention weights
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        # stack depth encoder blocks
        for _ in range(depth):
            # each block follows the paper: pre-norm attention followed by pre-norm feed-forward
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()

        # size of the original image
        image_height, image_width = pair(image_size)    # 224 * 224
        # size of each patch
        patch_height, patch_width = pair(patch_size)    # 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches, as in the paper: num_patches = H * W / P^2
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # flattened patch dimension, as in the paper
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # map each flattened patch to the dimension dim required by the encoder
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        # learnable position encodings for the cls token and all patch tokens
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # learnable initial value of the cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # final classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        # img: (1, 3, 224, 224) -> x: (1, 196, 1024)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # replicate the cls token so that every sample in the batch gets one
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        # concatenate the cls token with the patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)
        # add position information to every token after concatenation
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # if pool == 'mean', average over all token outputs; otherwise take the cls token (index 0)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        # fully connected classification head
        return self.mlp_head(x)



v = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,        # number of stacked encoder blocks
    heads = 16,       # number of attention heads
    mlp_dim = 2048,   # hidden dimension of the feed-forward layer
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

The einops library is used here to make tensor manipulations more readable. To learn how to use it, see the following post or the official documentation.

Blog post: einops 理解
Official documentation: einops GitHub
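
As a quick taste, the repeat call used in the ViT forward pass above copies the single class token across a batch in one readable line:

import torch
from einops import repeat

cls_token = torch.randn(1, 1, 1024)
# give every sample in a batch of 8 its own copy of the class token: (8, 1, 1024)
cls_tokens = repeat(cls_token, '() n d -> b n d', b = 8)
print(cls_tokens.shape)   # torch.Size([8, 1, 1024])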

This is only a simple ViT model meant to aid understanding; the ViT actually used for training is more complex. For reference, see:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

Reference:
Vision Transformer详解

Reposted from blog.csdn.net/qq_41533576/article/details/121107247