[Computer Vision] ViT: A Line-by-Line Interpretation of the Code

1. Code

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

2. Code interpretation

2.1 General understanding

This code is a PyTorch implementation of the Vision Transformer (ViT) model.

ViT is an image classification model based on the Transformer architecture. Its main idea is to divide the image into fixed-size patches, treat these patches as tokens, and feed them into a Transformer for feature extraction and classification.

The following is an interpretation of the code:

  1. The ViT class inherits from nn.Module. Its constructor takes a series of parameters, including the size of the input image, the patch size, the number of output classes, the number of attention heads, and so on.
  2. project_patches maps each flattened patch into a d-dimensional feature space through a fully connected (linear) layer.
  3. If classification=True, an extra CLS token is added at the beginning of the input token sequence, i.e. a CLS token of shape [1, 1, d] is prepended for each image. ViT also uses a learned absolute position encoding, so a 1D position embedding of shape [num_patches + 1, d] is added, where num_patches is the number of patches the image is divided into. If classification=False, no CLS token is added and the position embedding has shape [num_patches, d].
  4. The forward function first divides the input image into patches and maps each patch into the d-dimensional feature space with project_patches. If classification=True, the CLS token is prepended to the feature sequence. The position embedding is then added and dropout is applied, and the resulting features are fed into the Transformer for feature extraction. Finally, the output is produced: if classification=True, only the output of the CLS token is used for classification (see the usage sketch after this list).
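
To make this concrete, here is a minimal usage sketch of the class defined above (it assumes the self_attention_cv package is installed and the ViT class has been defined exactly as shown; the sizes are illustrative):

import torch

# 224x224 RGB input, 16x16 patches -> (224 / 16) ** 2 = 196 tokens
model = ViT(img_dim=224, in_channels=3, patch_dim=16, num_classes=10, dim=512)

img = torch.randn(2, 3, 224, 224)   # a batch of 2 images
logits = model(img)                 # classification=True, so only the CLS output is returned
print(logits.shape)                 # torch.Size([2, 10])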

2.2 Detailed understanding

from self_attention_cv import TransformerEncoder

self_attention_cv is a PyTorch-based library that provides modules and networks built on self-attention mechanisms for computer vision tasks, such as Transformer encoders and attention modules.

It is mainly aimed at tasks such as image classification, object detection, and semantic segmentation, and supports a variety of self-attention modules, including Simplified Self-Attention, Full Self-Attention, and Local Self-Attention. In addition, the library also provides implementations of some common computer vision models, such as the Vision Transformer (ViT) and the Swin Transformer.

TransformerEncoder is an encoder built on the self-attention mechanism that converts an input sequence into an encoded sequence. Self-attention allows the model to weight the representation of each position with respect to the other positions in the input sequence. This mechanism is widely used in natural language processing; for example, models such as BERT and GPT are built on it.

TransformerEncoder is implemented in PyTorch and can be used in computer vision tasks such as image classification, object detection, and semantic segmentation. It supports multi-head attention, residual connections, LayerNorm, and so on. In this code, the Transformer part of the ViT model uses TransformerEncoder as the default implementation.
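
As a stand-alone sketch, the encoder can be used exactly the way the ViT code below calls it (the arguments and the (x, mask) forward signature mirror that call; the tensor sizes are illustrative):

import torch
from self_attention_cv import TransformerEncoder

encoder = TransformerEncoder(512, blocks=6, heads=4,
                             dim_head=128, dim_linear_block=1024, dropout=0)
x = torch.randn(2, 197, 512)   # [batch, tokens, dim], e.g. 196 patches + 1 CLS token
out = encoder(x, None)         # the mask is optional; the output keeps the shape [2, 197, 512]
print(out.shape)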

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

This code defines a PyTorch model class called ViT, a Vision Transformer implemented with self-attention. The main parameters are:

  • img_dim: the spatial size of the input image
  • in_channels: the number of channels of the input image
  • patch_dim: the size of the fixed-size patches the image is divided into
  • num_classes: the number of classes for the classification task
  • dim: the dimension of the linear layer that projects each patch for MHSA
  • blocks: the number of Transformer blocks
  • heads: the number of attention heads
  • dim_linear_block: the inner dimension of the Transformer's feed-forward (linear) block
  • dim_head: the dimension of each head; defaults to dim / heads if not specified
  • dropout: the dropout probability for the position embedding and the Transformer
  • transformer: an optional transformer module to use instead of the default TransformerEncoder
  • classification: whether to add an extra CLS token for classification tasks

def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

The constructor of the ViT class is defined here. It takes multiple parameters, including the input image size img_dim, the number of input channels in_channels, the patch size patch_dim, the number of classes num_classes, the embedding dimension dim, the number of Transformer encoder blocks blocks, the number of attention heads heads, the inner dimension of the linear block dim_linear_block, the attention head dimension dim_head, the dropout probability dropout, an optional Transformer encoder transformer, and the classification flag classification.

    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

This checks whether img_dim is divisible by patch_dim and raises an assertion error if not. At the same time, patch_dim is stored in self.p, and the classification flag is stored in self.classification.

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

Here tokens counts how many patches the input image is divided into, and the dimension of each flattened patch self.token_dim is set to in_channels * (patch_dim ** 2).

The embedding dimension dim is stored in self.dim, and the attention head dimension self.dim_head is set to dim / heads unless dim_head is given explicitly. self.project_patches is a linear layer that projects each flattened patch into the embedding space.
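
As a worked example (the numbers are illustrative, not defaults of the code):

img_dim, patch_dim, in_channels, dim = 224, 16, 3, 512

tokens = (img_dim // patch_dim) ** 2        # 14 * 14 = 196 patches
token_dim = in_channels * (patch_dim ** 2)  # 3 * 16 * 16 = 768 values per flattened patch
# project_patches is then nn.Linear(768, 512): each flattened patch -> a 512-d embedding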

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

Here the dropout layer applied to the embeddings is defined. Depending on the classification flag, the class token self.cls_token, the position embedding self.pos_emb1D, and the MLP head self.mlp_head are created. If classification is not performed, self.cls_token and self.mlp_head are not needed, and the position embedding has only tokens rows instead of tokens + 1.
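
To make these shapes concrete, a quick check on an instantiated model (illustrative sizes; assumes the ViT class defined above):

model = ViT(img_dim=224, patch_dim=16, dim=512, num_classes=10)
print(model.cls_token.shape)   # torch.Size([1, 1, 512])
print(model.pos_emb1D.shape)   # torch.Size([197, 512]) -> 196 patches + 1 CLS token
print(model.mlp_head)          # Linear(in_features=512, out_features=10, bias=True)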

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                              dim_head=self.dim_head,
                                              dim_linear_block=dim_linear_block,
                                              dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout): defines a dropout layer that is applied to the embeddings.

if self.classification:: if it is a classification task, the following layers are created; otherwise they are skipped.

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): defines a trainable parameter cls_token representing the classification token; it is a tensor of shape [1, 1, dim], where dim is the embedding dimension.

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): defines a trainable parameter pos_emb1D representing the position embedding; it is a tensor of shape [tokens + 1, dim], where tokens is the number of patches the image is divided into and dim is the embedding dimension.

self.mlp_head = nn.Linear(dim, num_classes): defines a fully connected layer that maps the embedding to the number of output classes.

Finally, the default TransformerEncoder or a user-supplied transformer is chosen according to the argument passed in: if transformer is None, the default TransformerEncoder is used; otherwise the passed-in transformer is used.
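
Because the only requirement on a user-supplied transformer is that it can be called as transformer(x, mask) and return a tensor with the same token layout (see the forward function below), a custom encoder can be swapped in. The class below is a hypothetical, trivial stand-in used only to illustrate the interface:

import torch.nn as nn

class IdentityEncoder(nn.Module):
    # A do-nothing encoder: returns its input unchanged, matching the (x, mask) interface.
    def forward(self, x, mask=None):
        return x

vit_custom = ViT(img_dim=224, transformer=IdentityEncoder())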

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

The purpose of this method is to expand the classification (CLS) token to the whole batch. It takes a batch argument (the batch size) and returns a tensor of shape [batch, 1, dim], where dim is the embedding dimension of the Transformer model. The expansion is implemented with PyTorch's expand() method.
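
A small demonstration of what expand() does here (sizes are illustrative):

import torch

cls_token = torch.nn.Parameter(torch.randn(1, 1, 512))
expanded = cls_token.expand([8, -1, -1])   # -1 keeps the existing size of that dimension
print(expanded.shape)                      # torch.Size([8, 1, 512])
# expand() returns a broadcasted view: no data is copied, all rows share the same parameters.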

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

In the forward function, the input image img and an optional mask are received.

The number of tokens is determined by img_dim and patch_dim, where tokens is the number of patches the image is divided into.

The input img is divided into patches and reorganized into a tensor of shape [batch_size, tokens, patch_dim * patch_dim * in_channels].
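
The rearrange pattern can be checked in isolation (illustrative sizes):

import torch
from einops import rearrange

img = torch.randn(2, 3, 224, 224)   # [batch, channels, height, width]
patches = rearrange(img,
                    'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                    patch_x=16, patch_y=16)
print(patches.shape)                # torch.Size([2, 196, 768]) = [batch, tokens, token_dim]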

Each patch is mapped to dim dimensions by the linear layer project_patches.

If it is a classification task, a CLS token is inserted at the beginning of the sequence, i.e. concatenated with the projected patch tensor along the token dimension.

The position embedding pos_emb1D is then added, dropout is applied to obtain patch_embeddings, and the result is fed into the TransformerEncoder, which returns an output tensor y of shape [batch_size, tokens, dim] ([batch_size, tokens + 1, dim] when the CLS token is present).

If it is used for a classification task, the CLS token is taken from y (y[:, 0, :]), fed into the linear layer mlp_head, and the classification result is output.

If it is not a classification task, return y directly.
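
Putting it together, a hedged shape trace for a 224x224 RGB input with patch_dim=16, dim=512, and num_classes=10 (illustrative values), followed by the classification=False path:

# rearrange:        [2, 3, 224, 224] -> [2, 196, 768]
# project_patches:  [2, 196, 768]    -> [2, 196, 512]
# + CLS token:      [2, 196, 512]    -> [2, 197, 512]   (classification=True only)
# + pos_emb1D, dropout, transformer:    shape stays [2, 197, 512]
# mlp_head(y[:, 0, :]):              -> [2, 10]

import torch

feature_extractor = ViT(img_dim=224, patch_dim=16, dim=512, classification=False)
features = feature_extractor(torch.randn(2, 3, 224, 224))
print(features.shape)   # torch.Size([2, 196, 512]): one feature vector per patch, no CLS token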

Origin blog.csdn.net/wzk4869/article/details/130488137