[Computer Vision] ViT: コードの行ごとの解釈

1. コード

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {
      
      patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

二、コード解釈

2.1 大体理解

このコードは、Vision Transformer (ViT) モデルの PyTorch 実装です。

ViT は、Transformer アーキテクチャに基づく画像分類モデルです. その主なアイデアは、画像を固定サイズのパッチに分割し、これらのパッチをトークンとして扱い、特徴抽出と分類のために Transformer に入力することです.

以下は、コードの解釈です。

  1. ViT クラスはnn.Moduleクラスから継承し、そのコンストラクターには、入力画像のサイズ、パッチのサイズ、出力カテゴリの数、アテンション メカニズムのヘッドの数など、一連のパラメーターがあります。
  2. project_patchesこの関数は、全結合層を介して各パッチを d 次元の特徴空間にマッピングします。
  3. の場合classification = True、追加の CLS トークンが入力トークン シーケンスの先頭に追加されます。つまり、形状 [1, 1, d] の CLS トークンが各画像に追加されます。同時に、ViT では絶対位置エンコードが使用されるため、[num_patches + 1, d] の形状を持つ 1D 位置エンコード ベクトルが追加されます。ここで、num_patches は画像が分割されるパッチの数を示します。classification = False の場合、CLS トークンは追加されません。
  4. forwardこの関数は、最初に入力画像をパッチに分割し、project_patches 関数を介して各パッチを d 次元の特徴空間にマッピングします。次に、マッピングされたパッチ特徴ベクトルに位置エンコード ベクトルを追加し、ドロップアウト処理を行います。classification=True の場合、機能シーケンスの先頭に CLS トークンを追加します。次に、これらの特徴が特徴抽出のために Transformer に入力されます。最後に分類結果を出力します.classification=Trueの場合はCLSトークンの分類結果のみを返します.

2.2 詳細な理解

from self_attention_cv import TransformerEncoder

self_attention_cvTransformer Encoderは、や などのコンピューター ビジョン タスクで自己注意メカニズムを使用するモジュールとネットワークを提供する PyTorch ベースのライブラリですAttention Modules

Simplified Self-Attention主に、画像分類、オブジェクト検出、セマンティック セグメンテーションなどのタスクを対象としており、、、、、Full Self-Attentionなどのさまざまな自己注意モジュールの実装をサポートしていますLocal Self-Attentionさらに、ライブラリは、などの一般的なコンピューター ビジョン タスク モデルの実装も提供しVision Transformer(ViT)ますSwin Transformer

TransformerEncoder入力シーケンスをエンコードされたシーケンスに変換する自己注意メカニズムを備えたエンコーダーです。自己注意メカニズムにより、モデルは入力シーケンス内の他の位置に関して各位置の表現に重みを付けることができます。このメカニズムは自然言語処理で広く使用されており、たとえば、BERT や GPT などのモデルはすべて自己注意メカニズムを使用しています。

TransformerEncoderPyTorch に基づいて実装されており、画像分類、オブジェクト検出、セマンティック セグメンテーションなどのコンピューター ビジョン タスクで使用できます。マルチヘッド アテンション、残留接続、LayerNorm などの機能をサポートしています。このコードでは、ViT モデルの Transformer 部分がデフォルトの実装として TransformerEncoder を使用しています。

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {
      
      patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

このコードは、ViT と呼ばれる PyTorch モデル クラスを定義します。これは、Self-Attention を使用して実装された視覚的な Transformer モデルです。主なパラメータは次のとおりです。

  • img_dim: 入力画像の空間サイズ
  • in_channels: 入力画像のチャンネル数
  • patch_dim: イメージを固定サイズのパッチ サイズに分割します。
  • num_classes: 分類タスクのカテゴリ数
  • dim: 各パッチを MHSA 空間に投影するために使用される線形層の次元
  • blocks: Transformer モデルのブロック数
  • heads: 注意ヘッドの数
  • dim_linear_block:リニアブロック内寸法
  • dim_head: 各ヘッドの寸法。指定がない場合、デフォルトはdim/heads
  • dropout: 位置エンコーディングと Transformer のドロップアウト確率
  • transformer: オプションの TransformerEncoder クラス インスタンス
  • classification: 分類タスクに追加の CLS マーカーを含めるかどうか
def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

ViT クラスのコンストラクターはここで定義され、入力画像サイズimg_dim、入力チャンネル数in_channels、ブロック サイズpatch_dim、カテゴリ数、num_classes埋め込みdim次元、Transformer エンコーダーのブロック数blocks、ヘッドの数heads、および線形ブロックの次元dim_linear_block注意 ヘッドの次元dim_head、ドロップアウト確率dropout、オプションの Transformer エンコーダーtransformer、および分類用のフラグを強制しますclassification

    assert img_dim % patch_dim == 0, f'patch size {
      
      patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

これはimg_dimpatch_dimで割り切れるかどうかをチェックし、割り切れない場合はアサーション エラーを発生させます。同時にpatch_dimself.pに格納し、self.classification分類。

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

ここで、入力画像内の分割可能なブロックの数がカウントされtokens、各ブロックの次元が にself.token_dim設定されますin_channels * (patch_dim ** 2)

埋め込みdim次元self.dimを に格納し、 が Nonedim_headかどうかself.dim_headself.project_patches各ブロックを埋め込み空間に射影する線形層です。

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

ここでは、埋め込み層の層を定義しDropout、分類するかどうかのフラグに応じて、クラスフラグself.cls_token、位置埋め込みself.pos_emb1DMLPヘッダーを設定しますself.mlp_headソートされていない場合は、self.cls_token必要self.mlp_head

if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout):埋め込み後のドロップアウト用にレイヤーを定義しますdropout

if self.classification::分類タスクの場合は、次の操作を実行します。それ以外の場合はスキップします。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)):cls_token1x1xdim テンソルである分類トークンを表すトレーニング可能なパラメーターが定義されています。ここで、dim は埋め込み次元を表します。

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)):(tokens+1)xdim のテンソルである位置の埋め込みを表すトレーニング可能なパラメーター pos_emb1D が定義されています。ここで、tokens は画像が分割されるパッチの数を表し、dim は埋め込み次元を表します。

self.mlp_head = nn.Linear(dim, num_classes):埋め込みを出力カテゴリの数にマッピングする全結合層を定義します。

最後に、渡されたパラメーターに従って、デフォルトの TransformerEncoder またはインポートされたトランスフォーマーを使用することを選択します。渡されない場合は、デフォルトの TransformerEncoder が使用されます。それ以外の場合は、渡されたトランスフォーマーが使用されます。

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

このメソッドの機能は、Transformer の分類トークンをバッチ全体のサンプル数に拡張することです。バッチ パラメーターをバッチ サイズとして取り、形状 [batch, 1, dim] のテンソルを返します。ここで、dim は Transformer モデルの次元サイズです。このメソッドでは、展開操作を実装するために PyTorch の expand() メソッドが使用されます。

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

forward関数で、入力imgと を受け取りますmask

数はimg_dimによって計算されます。ここで、tokens はイメージが分割されるブロックの数です。patch_dimtokens

入力 img をパッチに分割し、再配置関数を使用して形状[batch_size, tokens, patch_dim * patch_dim * in_channels]の。

Linear各パッチはレイヤーによって薄暗い次元にマッピングされ、位置エンコーディングが追加されますpos_emb1D

分類タスクの場合は、シーケンスの先頭に 1 つ挿入しCLS token、処理されたパッチ テンソルと列ごとに連結します。

ドロップアウトを patch_embeddings に適用し、TransformerEncoder にフィードし、出力テンソル y を shape で返します[batch_size, tokens, dim]

分類タスクに使用する場合は、y から CLS トークンを取得し、分類のために Linear レイヤーに入力し、分類結果を出力します。

分類タスクでない場合は、y を直接返します。

おすすめ

転載: blog.csdn.net/wzk4869/article/details/130488137