1. Code
```python
import torch
import torch.nn as nn
from einops import rearrange
from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'image size {img_dim} is not divisible by patch size {patch_dim}'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)
        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        # x, y = number of patches per side (outer factor of each spatial axis);
        # patch_x, patch_y = pixels inside one patch (inner factor), so each token is a contiguous patch
        img_patches = rearrange(
            img, 'b c (x patch_x) (y patch_y) -> b (x y) (patch_x patch_y c)',
            patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)
        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)
        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
        # transformer output shape: [batch, tokens (+1 with CLS), dim]
        y = self.transformer(patch_embeddings, mask)
        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y
```
2. Code interpretation
2.1 General understanding
This code is a PyTorch implementation of the Vision Transformer (ViT) model.
ViT is an image classification model based on the Transformer architecture. Its main idea is to split the image into fixed-size patches, treat each patch as a token, and feed the token sequence into a Transformer for feature extraction and classification.
The following is an interpretation of the code:
- The ViT class inherits from nn.Module. Its constructor takes a series of parameters, including the size of the input image, the patch size, the number of output classes, the number of attention heads, and so on.
- project_patches is a fully connected (nn.Linear) layer that maps each flattened patch to a d-dimensional feature space.
- If classification=True, an extra learnable CLS token of shape [1, 1, d] is created, and in forward it is prepended to the token sequence of every image in the batch. ViT uses absolute position encoding, implemented here as a learnable 1D position embedding of shape [num_patches + 1, d], where num_patches is the number of patches the image is divided into. If classification=False, neither a CLS token nor a classification head is created, and the position embedding has shape [num_patches, d].
- The forward function first divides the input image into patches and maps each patch to the d-dimensional feature space with project_patches. If classification=True, the CLS token is then prepended to the sequence. Next, the position embedding is added, dropout is applied, and the resulting features are fed into the Transformer. If classification=True, only the output at the CLS position is passed to the classification head and the class scores are returned; otherwise the whole output sequence of shape [batch, num_patches, d] is returned.
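As a quick sanity check of the shapes described above, the following plain-Python sketch recomputes them from the constructor arguments. vit_shapes is a hypothetical helper written for this illustration, not part of the original code, and it assumes a square input image, as the ViT code does.

```python
def vit_shapes(batch, img_dim, in_channels=3, patch_dim=16, dim=512,
               num_classes=10, classification=True):
    """Hypothetical helper: tensor shapes along the ViT forward pass."""
    assert img_dim % patch_dim == 0, "img_dim must be divisible by patch_dim"
    tokens = (img_dim // patch_dim) ** 2                  # number of patches
    token_dim = in_channels * patch_dim ** 2              # length of one flattened patch
    seq_len = tokens + 1 if classification else tokens    # +1 for the CLS token
    shapes = {
        "patches": (batch, tokens, token_dim),            # after rearrange
        "projected": (batch, tokens, dim),                # after project_patches
        "pos_emb1D": (seq_len, dim),                      # learnable position embedding
        "encoder_out": (batch, seq_len, dim),             # transformer output
    }
    # Classification head on the CLS token, or the raw sequence otherwise.
    shapes["output"] = (batch, num_classes) if classification else (batch, seq_len, dim)
    return shapes


for name, shape in vit_shapes(batch=8, img_dim=224).items():
    print(f"{name:12s} {shape}")
```

With a 224 x 224 RGB image and 16 x 16 patches this gives 196 patches of length 768, i.e. a sequence of 197 tokens once the CLS token is added.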
2.2 Detailed understanding
```python
from self_attention_cv import TransformerEncoder
```
self_attention_cv is a PyTorch-based library that provides modules and networks using self-attention for computer vision tasks, such as TransformerEncoder and attention modules. It is mainly aimed at tasks such as image classification, object detection and semantic segmentation, and implements a variety of self-attention modules, including simplified self-attention, full self-attention and local self-attention. The library also provides implementations of some common computer vision models, such as Vision Transformer (ViT) and Swin Transformer.
TransformerEncoder is an encoder built on self-attention that converts an input sequence into an encoded sequence of the same shape. Self-attention lets the model weight the representation of each position by its relation to every other position in the input sequence. This mechanism is widely used in natural language processing; models such as BERT and GPT both rely on it.
TransformerEncoder is implemented in PyTorch and can be used in computer vision tasks such as image classification, object detection and semantic segmentation. It includes multi-head attention, residual connections and LayerNorm. In this code, it is the default Transformer used by the ViT model.
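The sketch below illustrates the "same shape in, same shape out" behavior of a Transformer encoder. It uses PyTorch's built-in nn.TransformerEncoder as a stand-in, which is an assumption: it is not the self_attention_cv TransformerEncoder, whose internals (for example where LayerNorm is applied) may differ. The hyperparameters mirror the ViT defaults.

```python
import torch
import torch.nn as nn

dim, heads, blocks, dim_linear_block = 512, 4, 6, 1024

# Stand-in encoder built from PyTorch's own layers, NOT self_attention_cv.TransformerEncoder
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                   dim_feedforward=dim_linear_block,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=blocks)

x = torch.randn(2, 197, dim)   # [batch, tokens + CLS, dim]
y = encoder(x, None)           # second positional argument is the attention mask
print(y.shape)                 # torch.Size([2, 197, 512])
```

Since ViT.forward calls self.transformer(patch_embeddings, mask), a module with this call signature could in principle be passed in through the transformer argument instead of the default.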
```python
def __init__(self, *,
             img_dim,
             in_channels=3,
             patch_dim=16,
             num_classes=10,
             dim=512,
             blocks=6,
             heads=4,
             dim_linear_block=1024,
             dim_head=None,
             dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'image size {img_dim} is not divisible by patch size {patch_dim}'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)
    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                              dim_head=self.dim_head,
                                              dim_linear_block=dim_linear_block,
                                              dropout=dropout)
    else:
        self.transformer = transformer
```
This code defines the constructor of ViT, a PyTorch Vision Transformer model built on self-attention. The main parameters are:
- img_dim: spatial size of the (square) input image
- in_channels: number of channels of the input image
- patch_dim: side length of each square patch the image is split into
- num_classes: number of classes for the classification task
- dim: output dimension of the linear layer that projects each patch into the space used by multi-head self-attention (MHSA)
- blocks: number of blocks in the Transformer encoder
- heads: number of attention heads
- dim_linear_block: inner (hidden) dimension of the linear block inside each Transformer block
- dim_head: dimension of each head; defaults to dim/heads if not specified
- dropout: dropout probability for the position embedding and the Transformer
- transformer: optional custom Transformer module to use instead of the default TransformerEncoder
- classification: whether to add an extra CLS token (and a classification head) for classification tasks
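The default for dim_head is a plain division truncated to an integer. The small sketch below uses a hypothetical helper, default_dim_head, that mirrors the constructor line self.dim_head = int(dim / heads) if dim_head is None else dim_head, to show that heads * dim_head equals dim only when dim is divisible by heads; the constructor itself does not check this.

```python
def default_dim_head(dim, heads, dim_head=None):
    # Hypothetical helper mirroring the ViT constructor: int(dim / heads) unless given.
    return int(dim / heads) if dim_head is None else dim_head


for dim, heads in [(512, 4), (512, 8), (512, 6)]:
    d = default_dim_head(dim, heads)
    print(dim, heads, d, heads * d == dim)
# 512 4 128 True
# 512 8 64 True
# 512 6 85 False
```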
```python
def __init__(self, *,
             img_dim,
             in_channels=3,
             patch_dim=16,
             num_classes=10,
             dim=512,
             blocks=6,
             heads=4,
             dim_linear_block=1024,
             dim_head=None,
             dropout=0, transformer=None, classification=True):
    super().__init__()
```
This defines the constructor of the ViT class. Its parameters are the input image size img_dim, the number of input channels in_channels, the patch size patch_dim, the number of classes num_classes, the embedding dimension dim, the number of Transformer encoder blocks blocks, the number of heads heads, the dimension of the linear block dim_linear_block, the attention head dimension dim_head, the dropout probability dropout, an optional Transformer encoder transformer, and the classification flag classification. The bare * at the start of the parameter list makes all of them keyword-only, so they must be passed by name, for example ViT(img_dim=224). super().__init__() initializes the nn.Module base class; it must run before any submodule or parameter is assigned to self.
```python
assert img_dim % patch_dim == 0, f'image size {img_dim} is not divisible by patch size {patch_dim}'
self.p = patch_dim
self.classification = classification
```
This checks whether img_dim is divisible by patch_dim and raises an AssertionError if it is not. It also stores patch_dim in self.p and the classification flag in self.classification.
```python
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
```
Here tokens, the number of patches the input image is divided into, is computed as (img_dim // patch_dim) ** 2, and the length of each flattened patch, self.token_dim, is set to in_channels * (patch_dim ** 2). The embedding dimension dim is stored in self.dim, and self.dim_head is set to int(dim / heads) if dim_head is None, otherwise to the given dim_head. self.project_patches is a linear layer that projects each flattened patch into the embedding space.
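To make the projection step concrete, the sketch below applies an nn.Linear with the same sizes as project_patches under the default settings (3 channels, 16 x 16 patches, dim=512) and an assumed 224 x 224 image, then counts its parameters.

```python
import torch
import torch.nn as nn

in_channels, patch_dim, dim, img_dim = 3, 16, 512, 224
tokens = (img_dim // patch_dim) ** 2            # 196 patches
token_dim = in_channels * patch_dim ** 2        # 768 values per flattened patch

project_patches = nn.Linear(token_dim, dim)
patches = torch.randn(2, tokens, token_dim)     # [batch, tokens, token_dim]
embedded = project_patches(patches)             # acts on the last dimension only

print(embedded.shape)                           # torch.Size([2, 196, 512])
n_params = sum(p.numel() for p in project_patches.parameters())
print(n_params)                                 # 768 * 512 weights + 512 biases = 393728
```

Because nn.Linear acts on the last dimension, the same weights are shared by every patch in the sequence.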
```python
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
    self.mlp_head = nn.Linear(dim, num_classes)
else:
    self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
```
Here the dropout layer applied to the embeddings is defined. Depending on the classification flag, the class token self.cls_token, the position embedding self.pos_emb1D and the classification head self.mlp_head (despite its name, a single linear layer) are created. When classification is False, neither self.cls_token nor self.mlp_head is needed, so only a position embedding with one row per patch is created.
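The sketch below shows how cls_token and pos_emb1D fit together at run time, using small arbitrary sizes and a random tensor as a stand-in for the projected patches: the [1, 1, dim] CLS token is expanded to the batch and concatenated in front of the patches, and the [tokens + 1, dim] position embedding is added by broadcasting over the batch axis. The last part shows why the extra row is needed.

```python
import torch
import torch.nn as nn

batch, tokens, dim = 4, 9, 16                            # small illustrative sizes

cls_token = nn.Parameter(torch.randn(1, 1, dim))         # one learnable CLS token
pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))   # one row per patch + one for CLS
patch_emb = torch.randn(batch, tokens, dim)              # stand-in for project_patches output

# Prepend the CLS token to every sample along the sequence axis.
seq = torch.cat((cls_token.expand(batch, -1, -1), patch_emb), dim=1)
print(seq.shape)                                         # torch.Size([4, 10, 16])

# [batch, tokens + 1, dim] + [tokens + 1, dim] broadcasts over the batch axis.
out = seq + pos_emb1D
print(out.shape)                                         # torch.Size([4, 10, 16])

# With only `tokens` rows there is no row for the CLS position, so the add fails.
mismatch_raised = False
try:
    seq + torch.randn(tokens, dim)
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)                                   # True
```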
```python
if transformer is None:
    self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                          dim_head=self.dim_head,
                                          dim_linear_block=dim_linear_block,
                                          dropout=dropout)
else:
    self.transformer = transformer
```
- self.emb_dropout = nn.Dropout(dropout): defines a dropout layer that is applied to the embeddings.
- if self.classification: the lines in this branch run only for a classification task; otherwise the else branch creates a position embedding of shape [tokens, dim] instead.
- self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): defines the trainable classification token cls_token, a tensor of shape [1, 1, dim], where dim is the embedding dimension.
- self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): defines the trainable position embedding pos_emb1D, a tensor of shape [tokens + 1, dim], where tokens is the number of patches the image is divided into and dim is the embedding dimension.
- self.mlp_head = nn.Linear(dim, num_classes): defines a fully connected layer that maps the CLS embedding to one score per output class.
Finally, the transformer argument decides which encoder is used: if it is None, a default TransformerEncoder is built from dim, blocks, heads, dim_head, dim_linear_block and dropout; otherwise the module passed in is used as is.
```python
def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])
```
This method expands the classification token to the number of samples in the batch. It takes the batch size batch and returns a tensor of shape [batch, 1, dim], where dim is the embedding dimension of the model. It uses PyTorch's expand(), where -1 keeps the size of that dimension unchanged. expand() returns a view without copying data, so every sample in the batch shares the same learnable CLS parameter.
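The sketch below checks these properties with a stand-alone nn.Parameter in place of self.cls_token (small arbitrary sizes): the expanded tensor has stride 0 along the batch axis and shares memory with the parameter, and during backpropagation the gradients from all samples add up in the single shared parameter.

```python
import torch
import torch.nn as nn

dim, batch = 8, 4
cls_token = nn.Parameter(torch.randn(1, 1, dim))

expanded = cls_token.expand([batch, -1, -1])         # same call as expand_cls_to_batch
print(expanded.shape)                                # torch.Size([4, 1, 8])
print(expanded.stride(0))                            # 0: every sample re-reads the same memory
print(expanded.data_ptr() == cls_token.data_ptr())   # True: a view, no copy

# Each of the `batch` copies contributes gradient 1, so the shared grad equals batch.
expanded.sum().backward()
print(torch.all(cls_token.grad == batch).item())     # True
```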
```python
def forward(self, img, mask=None):
    batch_size = img.shape[0]
    # x, y = number of patches per side (outer factor of each spatial axis);
    # patch_x, patch_y = pixels inside one patch (inner factor), so each token is a contiguous patch
    img_patches = rearrange(
        img, 'b c (x patch_x) (y patch_y) -> b (x y) (patch_x patch_y c)',
        patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)
    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)
    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
    # transformer output shape: [batch, tokens (+1 with CLS), dim]
    y = self.transformer(patch_embeddings, mask)
    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y
```
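The patch extraction can be cross-checked without einops. The sketch below uses a hypothetical helper, patchify, that splits an image into contiguous patch_dim x patch_dim patches with plain reshape/permute, and checks that token 0 is the top-left patch and token 1 the patch to its right. For contiguous patches, the number of patches per side has to be the outer factor of each spatial axis; in einops terms this should correspond to 'b c (x patch_x) (y patch_y) -> b (x y) (patch_x patch_y c)'.

```python
import torch


def patchify(img, p):
    """Hypothetical helper: [b, c, H, W] -> [b, (H/p)*(W/p), p*p*c] contiguous p x p patches."""
    b, c, h, w = img.shape
    assert h % p == 0 and w % p == 0
    x = img.reshape(b, c, h // p, p, w // p, p)   # b c x patch_x y patch_y
    x = x.permute(0, 2, 4, 3, 5, 1)               # b x y patch_x patch_y c
    return x.reshape(b, (h // p) * (w // p), p * p * c)


img = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
tokens = patchify(img, p=4)
print(tokens.shape)                               # torch.Size([2, 4, 48])

# Token 0 is the top-left 4 x 4 patch, flattened in (patch_x, patch_y, c) order.
top_left = img[:, :, :4, :4].permute(0, 2, 3, 1).reshape(2, -1)
print(torch.equal(tokens[:, 0], top_left))        # True
```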
In the forward function:
- The inputs img and mask are received.
- The number of tokens is determined by img_dim and patch_dim: tokens = (img_dim // patch_dim) ** 2 (computed in __init__) is the number of patches the image is divided into.
- The input img is divided into patches and rearranged into a tensor of shape [batch_size, tokens, patch_dim * patch_dim * in_channels].
- Each patch is mapped to dim dimensions by the Linear layer self.project_patches.
- For a classification task, the CLS token is expanded to the batch and concatenated in front of the projected patches along the sequence dimension (dim=1), giving tokens + 1 positions.
- The position embedding pos_emb1D is added, dropout is applied, and the resulting patch_embeddings are fed into the TransformerEncoder, which returns y with shape [batch_size, tokens + 1, dim] (or [batch_size, tokens, dim] without the CLS token).
- For a classification task, the CLS output y[:, 0, :] is passed to the Linear layer self.mlp_head, and the class scores of shape [batch_size, num_classes] are returned.
- Otherwise, y is returned directly.
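The last two steps can be sketched in isolation, using a random tensor as a stand-in for the encoder output (an assumption for illustration): indexing y[:, 0, :] keeps only the CLS position, and the linear head turns it into one score (logit) per class. The softmax and argmax lines are extra steps that the ViT code leaves to the caller; during training, nn.CrossEntropyLoss would take the raw logits directly.

```python
import torch
import torch.nn as nn

batch, tokens, dim, num_classes = 2, 196, 512, 10
y = torch.randn(batch, tokens + 1, dim)   # stand-in for the transformer output, CLS at index 0
mlp_head = nn.Linear(dim, num_classes)

cls_out = y[:, 0, :]                      # [batch, dim]: only the CLS position
logits = mlp_head(cls_out)                # [batch, num_classes]
print(cls_out.shape, logits.shape)        # torch.Size([2, 512]) torch.Size([2, 10])

probs = logits.softmax(dim=-1)            # class probabilities, if needed downstream
pred = logits.argmax(dim=-1)              # predicted class index per image
```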