ViT Notes and a PyTorch Implementation

ViT: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE (ICLR 2021)

PyTorch code source: https://github.com/lucidrains/vit-pytorch

1. Background

Transformers currently dominate NLP. Many researchers have tried to combine attention with CNNs for vision, but these approaches still rely on CNNs, and their performance still falls short of common convolutional networks such as ResNet.

In NLP, Transformer performance shows no sign of saturating as model size and dataset size keep increasing. Similarly in CV, when the amount of data is small, a Transformer is sometimes no better than a standard convolutional network; but as the dataset grows, the Transformer's performance keeps improving and can even surpass common convolutional models.

2. Method

The method presented in this paper is mainly used for classification: it takes an image as input and outputs a feature representation. The authors note in the conclusion that one remaining challenge is applying ViT to other vision tasks such as detection and segmentation. Their stated motivation is to change the original Transformer architecture as little as possible.

Since the paper itself is written quite concisely, and the main purpose of this post is to get familiar with and learn how to use ViT, it is best read together with the code (a PyTorch version found online, not the authors' official source code, for reference only).

[Figure: overview of the ViT architecture (original image unavailable)]

If you are familiar with the Transformer, the architecture in this paper is easy to follow. The basic pipeline is to first split an image $x \in \mathbb{R}^{H \times W \times C}$ into equally sized, non-overlapping patches $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$, where $HW = NP^2$ and $(P, P)$ is the size of each patch. Each patch is then mapped by a linear projection to $x_{patch} \in \mathbb{R}^{D}$.
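For example, with the common ViT-Base input setting (224×224 RGB images and 16×16 patches; these numbers are an assumption for illustration, not something fixed by the method), the dimensions work out as follows:

# Illustrative numbers only: assumed 224x224 RGB input with 16x16 patches
H, W, C, P = 224, 224, 3, 16
N = (H * W) // (P * P)    # number of patches, N = HW / P^2 = 196
patch_dim = P * P * C     # flattened patch dimension P^2 * C = 768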

The code for this linear projection is shown below; for an introduction to rearrange from the einops library, see einops.rearrange.

from einops.layers.torch import Rearrange

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim), # dim=D
        )
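
As a quick shape check of this projection (a minimal, self-contained sketch using the same assumed numbers as above; the variable names outside the nn.Sequential are illustrative):

import torch
from torch import nn
from einops.layers.torch import Rearrange

patch_height = patch_width = 16   # assumed patch size
patch_dim = 16 * 16 * 3           # P^2 * C = 768
dim = 768                         # D

to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
)

img = torch.randn(2, 3, 224, 224)   # a batch of 2 whole RGB images
tokens = to_patch_embedding(img)
print(tokens.shape)                 # torch.Size([2, 196, 768]) -> (b, N, D)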

Next, as in BERT, a [class] token $x_{class}$ is prepended to the patch sequence. $x_{class}$ does not go through the linear projection layer; it is a learnable parameter, created directly in PyTorch as self.cls_token = nn.Parameter(torch.randn(1, 1, dim)), where $dim = D$.

The output $y$ obtained from $x_{class}$ after it passes through the Transformer encoder is the final output used for classification.

The Transformer uses positional encodings. Here the authors use a 1D positional encoding, because their experiments showed that a 2D positional encoding did not bring a notable improvement. Interestingly, the paper states that this positional information is learned rather than given in advance (unlike fixed encodings such as 1, 2, 3, ...).

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # the positional embedding here is a learnable parameter
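
A minimal sketch of how this parameter behaves (illustrative shapes only; in the full model the addition happens inside ViT.forward, shown later):

import torch
from torch import nn

num_patches, dim = 196, 768                            # illustrative numbers
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

tokens = torch.randn(2, num_patches + 1, dim)          # (b, N + 1, D): patch tokens plus the [class] token
tokens = tokens + pos_embedding[:, :tokens.shape[1]]   # broadcast over the batch dimension
print(pos_embedding.requires_grad)                     # True -- trained by backpropagation like any other weight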

Next comes the structure of the Transformer encoder, which consists of multi-headed self-attention (MSA) and MLP blocks. LayerNorm (LN) is applied in every layer (as pre-norm in this implementation), and residual connections are used around each block. The code is as follows:

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)                    # project to q, k, v with a single linear layer
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)  # split into heads

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale   # scaled dot-product attention scores

        attn = self.attend(dots)                                   # softmax over the key dimension
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')               # merge the heads back together
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x   # MSA block: pre-LayerNorm + residual connection
            x = ff(x) + x     # MLP block: pre-LayerNorm + residual connection
        return x
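
A quick sanity check of the encoder on its own (a sketch with illustrative hyperparameters roughly matching ViT-Base; it assumes the Transformer class defined above):

encoder = Transformer(dim = 768, depth = 12, heads = 12, dim_head = 64, mlp_dim = 3072, dropout = 0.1)
tokens = torch.randn(2, 197, 768)   # (b, N + 1, D): 196 patch tokens plus the [class] token
out = encoder(tokens)
print(out.shape)                    # torch.Size([2, 197, 768]) -- the sequence shape is preserved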

Next is the overall implementation of ViT. Note that every element of the batch passed to this class is a whole image.

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)                               # (b, N, D): patchify and linearly project
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)   # one [class] token per image
        x = torch.cat((cls_tokens, x), dim=1)                          # (b, N + 1, D)
        x += self.pos_embedding[:, :(n + 1)]                           # add the learnable positional embedding
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]  # two options: 'mean' averages the features of all patch tokens, while 'cls' (as in the paper) takes the [class] token

        x = self.to_latent(x)
        return self.mlp_head(x)
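
A minimal usage sketch (the hyperparameters below roughly follow ViT-Base and are assumptions for illustration, not values given in this post):

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 3072,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(2, 3, 224, 224)   # a batch of 2 whole RGB images
logits = model(img)
print(logits.shape)                 # torch.Size([2, 1000])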

The computation defined in the paper is as follows, where $z_L^0$ denotes the feature of $x_{class}$ output by the final ($L$-th) Transformer layer.

$$z_0 = [x_{class};\, x_p^1 E;\, x_p^2 E;\, \cdots;\, x_p^N E] + E_{pos}, \qquad E \in \mathbb{R}^{(P^2 \cdot C) \times D},\ E_{pos} \in \mathbb{R}^{(N+1) \times D}$$

$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad \ell = 1, \dots, L$$

$$z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, \qquad \ell = 1, \dots, L$$

$$y = \mathrm{LN}(z_L^0)$$

In the code above, the final LayerNorm in the last equation corresponds to the nn.LayerNorm inside mlp_head.

The authors also propose a hybrid architecture, in which the input sequence is formed not from raw image patches but from feature maps extracted by a CNN.
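A rough sketch of this hybrid idea (only an illustration, not the paper's exact configuration: it assumes a torchvision ResNet-50 trunk and a 2048-to-768 projection):

import torch
from torch import nn
from torchvision.models import resnet50
from einops import rearrange

backbone = nn.Sequential(*list(resnet50().children())[:-2])   # CNN trunk only: outputs (b, 2048, h', w')
proj = nn.Linear(2048, 768)                                   # map CNN channels to the Transformer dimension D

img = torch.randn(2, 3, 224, 224)
feat = backbone(img)                                # (2, 2048, 7, 7)
tokens = rearrange(feat, 'b c h w -> b (h w) c')    # each spatial position becomes one "patch" token
tokens = proj(tokens)                               # (2, 49, 768) -- fed to the Transformer encoder in place of image patches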

If you want to use ViT as a feature extractor, you can remove the final mlp_head and attach a head for the downstream task, then fine-tune from pretrained weights.
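A minimal sketch of that idea, reusing the ViT class above (the downstream head is hypothetical and only for illustration):

feature_extractor = ViT(image_size = 224, patch_size = 16, num_classes = 1000,
                        dim = 768, depth = 12, heads = 12, mlp_dim = 3072)
feature_extractor.mlp_head = nn.Identity()   # drop the classification head; forward() now returns the pooled D-dim feature

downstream_head = nn.Linear(768, 10)         # hypothetical head for a 10-class downstream task

img = torch.randn(2, 3, 224, 224)
features = feature_extractor(img)            # (2, 768)
logits = downstream_head(features)           # (2, 10)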

Reposted from blog.csdn.net/qq_43428929/article/details/129984026