1. コード
import torch
import torch.nn as nn
from einops import rearrange
from self_attention_cv import TransformerEncoder
class ViT(nn.Module):
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
"""
Args:
img_dim: the spatial image size
in_channels: number of img channels
patch_dim: desired patch dim
num_classes: classification task classes
dim: the linear layer's dim to project the patches for MHSA
blocks: number of transformer blocks
heads: number of heads
dim_linear_block: inner dim of the transformer linear block
dim_head: dim head in case you want to define it. defaults to dim/heads
dropout: for pos emb and transformer
transformer: in case you want to provide another transformer implementation
classification: creates an extra CLS token
"""
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {
patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
def expand_cls_to_batch(self, batch):
"""
Args:
batch: batch size
Returns: cls token expanded to the batch size
"""
return self.cls_token.expand([batch, -1, -1])
def forward(self, img, mask=None):
batch_size = img.shape[0]
img_patches = rearrange(
img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.p, patch_y=self.p)
# project patches with linear layer + add pos emb
img_patches = self.project_patches(img_patches)
if self.classification:
img_patches = torch.cat(
(self.expand_cls_to_batch(batch_size), img_patches), dim=1)
patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
# feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
y = self.transformer(patch_embeddings, mask)
if self.classification:
# we index only the cls token for classification. nlp tricks :P
return self.mlp_head(y[:, 0, :])
else:
return y
二、コード解釈
2.1 大体理解
このコードは、Vision Transformer (ViT) モデルの PyTorch 実装です。
ViT は、Transformer アーキテクチャに基づく画像分類モデルです. その主なアイデアは、画像を固定サイズのパッチに分割し、これらのパッチをトークンとして扱い、特徴抽出と分類のために Transformer に入力することです.
以下は、コードの解釈です。
- ViT クラスは
nn.Module
クラスから継承し、そのコンストラクターには、入力画像のサイズ、パッチのサイズ、出力カテゴリの数、アテンション メカニズムのヘッドの数など、一連のパラメーターがあります。 project_patches
この関数は、全結合層を介して各パッチを d 次元の特徴空間にマッピングします。- の場合
classification = True
、追加の CLS トークンが入力トークン シーケンスの先頭に追加されます。つまり、形状 [1, 1, d] の CLS トークンが各画像に追加されます。同時に、ViT では絶対位置エンコードが使用されるため、[num_patches + 1, d] の形状を持つ 1D 位置エンコード ベクトルが追加されます。ここで、num_patches は画像が分割されるパッチの数を示します。classification = False の場合、CLS トークンは追加されません。 forward
この関数は、最初に入力画像をパッチに分割し、project_patches 関数を介して各パッチを d 次元の特徴空間にマッピングします。次に、マッピングされたパッチ特徴ベクトルに位置エンコード ベクトルを追加し、ドロップアウト処理を行います。classification=True の場合、機能シーケンスの先頭に CLS トークンを追加します。次に、これらの特徴が特徴抽出のために Transformer に入力されます。最後に分類結果を出力します.classification=Trueの場合はCLSトークンの分類結果のみを返します.
2.2 詳細な理解
from self_attention_cv import TransformerEncoder
self_attention_cv
Transformer Encoder
は、や などのコンピューター ビジョン タスクで自己注意メカニズムを使用するモジュールとネットワークを提供する PyTorch ベースのライブラリですAttention Modules
。
Simplified Self-Attention
主に、画像分類、オブジェクト検出、セマンティック セグメンテーションなどのタスクを対象としており、、、、、Full Self-Attention
などのさまざまな自己注意モジュールの実装をサポートしていますLocal Self-Attention
。さらに、ライブラリは、などの一般的なコンピューター ビジョン タスク モデルの実装も提供しVision Transformer(ViT)
ますSwin Transformer
。
TransformerEncoder
入力シーケンスをエンコードされたシーケンスに変換する自己注意メカニズムを備えたエンコーダーです。自己注意メカニズムにより、モデルは入力シーケンス内の他の位置に関して各位置の表現に重みを付けることができます。このメカニズムは自然言語処理で広く使用されており、たとえば、BERT や GPT などのモデルはすべて自己注意メカニズムを使用しています。
TransformerEncoder
PyTorch に基づいて実装されており、画像分類、オブジェクト検出、セマンティック セグメンテーションなどのコンピューター ビジョン タスクで使用できます。マルチヘッド アテンション、残留接続、LayerNorm などの機能をサポートしています。このコードでは、ViT モデルの Transformer 部分がデフォルトの実装として TransformerEncoder を使用しています。
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {
patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
このコードは、ViT と呼ばれる PyTorch モデル クラスを定義します。これは、Self-Attention を使用して実装された視覚的な Transformer モデルです。主なパラメータは次のとおりです。
img_dim
: 入力画像の空間サイズin_channels
: 入力画像のチャンネル数patch_dim
: イメージを固定サイズのパッチ サイズに分割します。num_classes
: 分類タスクのカテゴリ数dim
: 各パッチを MHSA 空間に投影するために使用される線形層の次元blocks
: Transformer モデルのブロック数heads
: 注意ヘッドの数dim_linear_block
:リニアブロック内寸法dim_head
: 各ヘッドの寸法。指定がない場合、デフォルトはdim/headsdropout
: 位置エンコーディングと Transformer のドロップアウト確率transformer
: オプションの TransformerEncoder クラス インスタンスclassification
: 分類タスクに追加の CLS マーカーを含めるかどうか
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
super().__init__()
ViT クラスのコンストラクターはここで定義され、入力画像サイズimg_dim
、入力チャンネル数in_channels
、ブロック サイズpatch_dim
、カテゴリ数、num_classes
埋め込みdim
次元、Transformer エンコーダーのブロック数blocks
、ヘッドの数heads
、および線形ブロックの次元dim_linear_block
注意 ヘッドの次元dim_head
、ドロップアウト確率dropout
、オプションの Transformer エンコーダーtransformer
、および分類用のフラグを強制しますclassification
。
assert img_dim % patch_dim == 0, f'patch size {
patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
これはimg_dim
がpatch_dim
で割り切れるかどうかをチェックし、割り切れない場合はアサーション エラーを発生させます。同時にpatch_dim
、self.p
に格納し、self.classification
分類。
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
ここで、入力画像内の分割可能なブロックの数がカウントされtokens
、各ブロックの次元が にself.token_dim
設定されますin_channels * (patch_dim ** 2)
。
埋め込みdim
次元self.dim
を に格納し、 が Nonedim_head
かどうかself.dim_head
。self.project_patches
各ブロックを埋め込み空間に射影する線形層です。
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
ここでは、埋め込み層の層を定義しDropout
、分類するかどうかのフラグに応じて、クラスフラグself.cls_token
、位置埋め込みself.pos_emb1D
、MLP
ヘッダーを設定しますself.mlp_head
。ソートされていない場合は、self.cls_token
必要self.mlp_head
。
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
self.emb_dropout = nn.Dropout(dropout):
埋め込み後のドロップアウト用にレイヤーを定義しますdropout
。
if self.classification::
分類タスクの場合は、次の操作を実行します。それ以外の場合はスキップします。
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)):
cls_token
1x1xdim テンソルである分類トークンを表すトレーニング可能なパラメーターが定義されています。ここで、dim は埋め込み次元を表します。
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)):
(tokens+1)xdim のテンソルである位置の埋め込みを表すトレーニング可能なパラメーター pos_emb1D が定義されています。ここで、tokens は画像が分割されるパッチの数を表し、dim は埋め込み次元を表します。
self.mlp_head = nn.Linear(dim, num_classes):
埋め込みを出力カテゴリの数にマッピングする全結合層を定義します。
最後に、渡されたパラメーターに従って、デフォルトの TransformerEncoder またはインポートされたトランスフォーマーを使用することを選択します。渡されない場合は、デフォルトの TransformerEncoder が使用されます。それ以外の場合は、渡されたトランスフォーマーが使用されます。
def expand_cls_to_batch(self, batch):
"""
Args:
batch: batch size
Returns: cls token expanded to the batch size
"""
return self.cls_token.expand([batch, -1, -1])
このメソッドの機能は、Transformer の分類トークンをバッチ全体のサンプル数に拡張することです。バッチ パラメーターをバッチ サイズとして取り、形状 [batch, 1, dim] のテンソルを返します。ここで、dim は Transformer モデルの次元サイズです。このメソッドでは、展開操作を実装するために PyTorch の expand() メソッドが使用されます。
def forward(self, img, mask=None):
batch_size = img.shape[0]
img_patches = rearrange(
img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.p, patch_y=self.p)
# project patches with linear layer + add pos emb
img_patches = self.project_patches(img_patches)
if self.classification:
img_patches = torch.cat(
(self.expand_cls_to_batch(batch_size), img_patches), dim=1)
patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
# feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
y = self.transformer(patch_embeddings, mask)
if self.classification:
# we index only the cls token for classification. nlp tricks :P
return self.mlp_head(y[:, 0, :])
else:
return y
forward
関数で、入力img
と を受け取りますmask
。
数はimg_dim
とによって計算されます。ここで、tokens はイメージが分割されるブロックの数です。patch_dim
tokens
入力 img をパッチに分割し、再配置関数を使用して形状[batch_size, tokens, patch_dim * patch_dim * in_channels]
の。
Linear
各パッチはレイヤーによって薄暗い次元にマッピングされ、位置エンコーディングが追加されますpos_emb1D
。
分類タスクの場合は、シーケンスの先頭に 1 つ挿入しCLS token
、処理されたパッチ テンソルと列ごとに連結します。
ドロップアウトを patch_embeddings に適用し、TransformerEncoder にフィードし、出力テンソル y を shape で返します[batch_size, tokens, dim]
。
分類タスクに使用する場合は、y から CLS トークンを取得し、分類のために Linear レイヤーに入力し、分類結果を出力します。
分類タスクでない場合は、y を直接返します。