Build Vision Transformers from scratch (PyTorch version)


preface

Since the Vision Transformer (ViT) was proposed by Dosovitskiy et al. in 2020, it has gradually come to dominate the field of computer vision, achieving strong performance on downstream tasks such as image classification, object detection, and semantic segmentation, and setting off a wave of Transformer-based models in CV. Here I will show how to implement the ViT model step by step from scratch with the PyTorch framework.

foreword

If you are not familiar with the Transformer model used in natural language processing (NLP), you may be a little confused about how Transformers are applied in the CV field and how the ViT model operates on images. Don't worry: here we will implement a first ViT from scratch, using PyTorch. Here we go!

Defining the task

For beginners, we choose an entry-level dataset: MNIST handwritten digits, used for image classification. Although the goal is simple, this classification task lets us walk through the whole structure of the ViT model. Briefly, MNIST is a dataset of handwritten digits ([0-9]) in which every image is a 28x28 grayscale image.
First, import the PyTorch modules we will need:

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

Let's create the main function to preprocess the MNIST dataset, instantiate the model, define the loss, use the Adam optimizer, train for 50 epochs, and then calculate the accuracy on the test set.

def main():
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=16)

    # Defining model and training options
    model = MyViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10) # TODO define ViT model 
    N_EPOCHS = 50
    LR = 0.01

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in range(N_EPOCHS):
        train_loss = 0.0
        for batch in train_loader:
            x, y = batch
            y_hat = model(x)
            loss = criterion(y_hat, y) / len(x)

            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {
      
      epoch + 1}/{
      
      N_EPOCHS} loss: {
      
      train_loss:.2f}")

    # Test loop
    correct, total = 0, 0
    test_loss = 0.0
    for batch in test_loader:
        x, y = batch
        y_hat = model(x)
        loss = criterion(y_hat, y) / len(x)
        test_loss += loss.item()

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")
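
To actually run the script, a standard Python entry point can be added at the bottom of the file, for example:

if __name__ == '__main__':
    main()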

After building the entire training and testing framework, let's tackle the construction of the ViT model. The task of the model is to classify (Nx1x28x28) images. We first define an empty nn.Module class, and then gradually fill it:

class MyViT(nn.Module):
    def __init__(self):
        # Super constructor
        super(MyViT, self).__init__()

    def forward(self, images):
        pass

ViT architecture

Since PyTorch (like most DL frameworks) provides autograd, we only need to implement the forward pass of the ViT model. The optimizer has already been defined in the training framework above, and PyTorch takes care of backpropagating the gradients and updating the model's parameters.
[Figure: the overall ViT architecture from the original paper — patch embedding, positional encoding, Transformer encoder, and MLP classification head]
The figure above shows the overall ViT network architecture. The input image (a) is first cut into equally sized patch sub-images, and each patch is passed through a Linear Embedding, i.e. a fully connected layer applied to each flattened patch vector, which prepares the input for the Transformer (for the motivation behind this Linear Embedding, see the author's explanation in [1], which walks through the Transformer series in great detail and is highly recommended). After the Linear Embedding layer, a positional encoding is added so that the relative position of each patch within the image is taken into account. The result then goes through the Transformer encoder, and finally an MLP classification head outputs the image's class.
[Figure: the same forward pass annotated with tensor dimensions at each step, from [1]]
The picture above, from [1], annotates the forward pass with the tensor dimensions at each step to make it easier to follow. Below we build ViT in six main steps.

Patchifying and Linear Maps

The Transformer encoder was originally designed for sequence data such as text in NLP. The first step in using it for CV is therefore to "serialize" the images: we decompose each image into multiple patches and map each patch to a vector.
On the MNIST dataset, we split each (1x28x28) image into a 7x7 grid of patches, each of size 4x4 (if the image size is not exactly divisible, padding has to be added), so a single image yields 49 patches. The original image is reshaped to:
(N, P x P, C x H/P x W/P) = (N, 7x7, 1x4x4) = (N, 49, 16)
Note that although each patch has size 1x4x4, we flatten it into a 16-dimensional vector. MNIST has only one color channel; if there were multiple channels, they would be flattened into the vector as well.
[Figure: an image being split into patches, with each patch flattened into a vector]
Let's implement this in code:

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)

        # Flattened dimension of each patch (channels * patch height * patch width)
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])


    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)
        return patches 
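
One caveat worth flagging: a plain reshape reads the image buffer row by row, so the resulting 16-value chunks are not actually spatially contiguous 4x4 blocks. As an illustration only (this helper is my own addition, not part of the original code), a patchify that extracts true 4x4 patches could be written with Tensor.unfold:

def patchify(images, n_patches):
    # images: (N, C, H, W) -> (N, n_patches ** 2, C * (H // n_patches) * (W // n_patches))
    n, c, h, w = images.shape
    p_h, p_w = h // n_patches, w // n_patches

    # Extract non-overlapping p_h x p_w blocks along the height and width dimensions
    patches = images.unfold(2, p_h, p_h).unfold(3, p_w, p_w)   # (N, C, n_patches, n_patches, p_h, p_w)
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()   # (N, n_patches, n_patches, C, p_h, p_w)
    return patches.view(n, n_patches ** 2, c * p_h * p_w)

Either way each token is a fixed group of pixels, so the model can still learn on MNIST, but the unfold version matches the patch picture above more faithfully.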

Now that we have the flattened patches, each a 16-dimensional vector, we pass them through a linear map, which can project them to any vector size. We add a hidden_d parameter to the class constructor for the "hidden dimension". Here the hidden dimension is 8, so each 16-dimensional patch is mapped to an 8-dimensional token; the implementation is as follows.

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)
         # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        return tokens  
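
As a quick sanity check (a throwaway snippet, not part of the final script), running a dummy batch through the model at this stage should yield tokens of shape (N, 49, 8):

model = MyViT((1, 28, 28), n_patches=7, hidden_d=8)
x = torch.randn(3, 1, 28, 28)   # 3 random grayscale 28x28 images
print(model(x).shape)           # torch.Size([3, 49, 8])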

Adding the classification token

On top of the linearly mapped tokens, we need a special classification token to carry out the classification task; the reasoning is explained in [1], here we only do the implementation. We now add a learnable parameter to our model and prepend it to each sequence, converting our (N, 49, 8) tensor into an (N, 50, 8) tensor.

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)
         # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)
        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        return tokens 

Note that the (N, 49, 8) → (N, 50, 8) implementation here may not be optimal. Also note that the classification token is placed at the first position of each sequence; we will need to pick out this same position when we build the final classification MLP.
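
For instance, a vectorized alternative (a sketch equivalent in effect to the loop above) would expand the class token across the batch and concatenate it in a single call:

# Inside forward, instead of the stack/vstack loop (sketch):
cls_tokens = self.class_token.expand(n, 1, -1)    # (N, 1, hidden_d), a broadcasted view
tokens = torch.cat([cls_tokens, tokens], dim=1)   # (N, 50, hidden_d)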

Adding positional encoding

For positional encoding, we follow the positional embeddings used in the original Transformer. Although these embeddings can in principle be learned, prior work suggests that fixed sine and cosine waves work just as well, so we simply add them.
[Formula: PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d))]

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

[Figure: heatmap of the sinusoidal positional embeddings]
From the plotted heatmap we can see that all the "horizontal lines" (rows) differ from one another, so each position in the sequence can be distinguished.
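
The heatmap can be reproduced with a few lines of matplotlib (an extra dependency used only for this visualization; the exact colormap is a guess):

import matplotlib.pyplot as plt

pe = get_positional_embeddings(50, 8)   # (sequence_length, d)
plt.imshow(pe.numpy(), cmap='hot', interpolation='nearest', aspect='auto')
plt.xlabel('embedding dimension')
plt.ylabel('token position')
plt.colorbar()
plt.show()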

The positional encoding can now be added to the model, after the linear mapping and the classification token:

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)



    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

        return tokens

Since the token size is (N, 50, 8), we need to repeat the (50, 8) position encoding matrix N times.
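
As a small possible refinement (not used in the code above), the positional embedding could be computed once in __init__ and stored as a non-trainable buffer instead of being rebuilt on every forward pass:

# In __init__ (sketch):
self.register_buffer(
    'positional_embeddings',
    get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d),
    persistent=False,
)

# In forward (sketch):
tokens = tokens + self.positional_embeddings.repeat(n, 1, 1)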

LN, MSA and residual connections

This is the most complicated step. We need to do layer normalization on the tokens first, then apply the multi-head attention mechanism, and finally add a residual connection (connect the input before the LN and the output after the multi-head attention).

LN

We typically apply LN to (N, d) inputs, where d is the dimension. It was only when implementing ViT that I realized nn.LayerNorm can also be applied over multiple dimensions:
[Figure: excerpt from the nn.LayerNorm documentation]
After running the (N, 50, 8) tensor through LN, each 50x8 matrix has mean 0 and standard deviation 1, and the shape stays unchanged.
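
A quick standalone check of this behaviour (using only torch) might look like:

ln = nn.LayerNorm((50, 8))
x = torch.randn(16, 50, 8)
y = ln(x)
print(y.shape)                                # torch.Size([16, 50, 8])
print(y.mean(dim=(1, 2)), y.std(dim=(1, 2)))  # per-sample mean ~0, std ~1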

multi-headed self-attention

We now need to implement sub-figure (c) of the architecture diagram: the multi-head self-attention mechanism (if the mechanism itself is unfamiliar, see [2]). In short: for a single image, we want each patch to be updated according to some similarity measure with the other patches. To do this, each patch (now an 8-dimensional vector in our example) is linearly mapped to 3 different vectors: q, k and v (query, key, value). Then, for a single patch, we compute the dot product between its q vector and all the k vectors, divide by the square root of the dimension of these vectors, apply a softmax to the result, and finally take a weighted sum of the v vectors using the softmax weights. The full formula is as follows.
[Formula: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V]
In this way, each patch takes on a new value that reflects its similarity to the other patches (after the linear mappings to q, k and v). This whole process is a single head; for multiple heads, the process is repeated in parallel and the results from all heads are then concatenated together (the original Transformer additionally passes them through a final linear projection, which the code below omits for simplicity).

Since this involves quite a lot of computation, we create a separate class for the MSA:

class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        # Register the per-head linear layers in nn.ModuleList so their parameters
        # are visible to the optimizer (plain Python lists would hide them)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # Sequences has shape (N, seq_length, token_dim)
        # We go into shape    (N, seq_length, n_heads, token_dim / n_heads)
        # And come back to    (N, seq_length, item_dim)  (through concatenation)
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

Note that for each head we create different Q, K and V mapping functions (here, square matrices of size 4x4).

Since the input is a sequence of size (N, 50, 8) and we use 2 heads, at some point we will be working with (N, 50, 2, 4) tensors, apply the nn.Linear(4, 4) modules head by head, and then come back to an (N, 50, 8) tensor after concatenation. Also, looping over sequences and heads is not the most efficient way to compute multi-head self-attention, but the code is cleaner this way.
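
If efficiency matters, PyTorch's built-in nn.MultiheadAttention could stand in for MyMSA as a fully vectorized alternative (its internal projections differ from the per-head 4x4 mappings above, so the results will not match exactly); a minimal sketch:

class TorchMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(TorchMSA, self).__init__()
        # batch_first=True makes inputs/outputs (N, seq_length, d)
        self.attn = nn.MultiheadAttention(embed_dim=d, num_heads=n_heads, batch_first=True)

    def forward(self, sequences):
        # Self-attention: the sequence serves as query, key and value
        out, _ = self.attn(sequences, sequences, sequences)
        return out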

residual connection

A residual connection will be added which adds our original (N, 50, 8) tensor to the (N, 50, 8) obtained after LN and MSA.

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)



    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))

        return out 

Note that if we now run random (3, 1, 28, 28) images of MNIST through our model, we still get results of shape (3, 50, 8).

LN, MLP and Residual Connections

Continuing with the network, we pass the current tensor through another LN and an MLP, again with a residual connection, and build it up like this:

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )


    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        return out 

This way, if we feed the model a random (3, 1, 28, 28) image tensor, the output out is still a (3, 50, 8) tensor.

Classification MLP

Finally, we extract just the classification token (the first token) from each of the N sequences, matching the position where it was added, and use it to obtain N classifications.

Since we decided that each token is an 8-dimensional vector and there are 10 possible digits, the classification MLP can be implemented as a simple 8x10 linear layer followed by a SoftMax activation.

class MyViT(nn.Module):
    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # (In forward method)

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MyMSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # Dividing images into patches
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)

The output of our model is now an (N, 10) tensor. ok, you're done!

Now let's see how our model performs. With the Torch seed manually set to 0 and the run executed on CPU:
[Figure: console output of the training run]
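
For reference, reproducing such a run would presumably amount to fixing the seed before calling main():

torch.manual_seed(0)   # fix the random seed so the run is reproducible
main()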

epilogue

The authors of the original ViT used a GELU activation function, a deeper MLP, and stacked multiple Transformer encoder blocks together. What we built here is the simplest bare-bones version; you can extend it from this base.
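
As an illustration of that last point, one way (a sketch of my own, not the paper's reference code) to stack several encoder blocks is to factor the LN/MSA/MLP part into its own module and chain a few of them:

class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.ln1 = nn.LayerNorm(hidden_d)
        self.msa = MyMSA(hidden_d, n_heads)
        self.ln2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        x = x + self.msa(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

# Then, in MyViT.__init__ (sketch):
#     self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
# and in MyViT.forward:
#     for block in self.blocks:
#         tokens = block(tokens)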


Paper: https://arxiv.org/abs/2010.11929
References:
[1] https://zhuanlan.zhihu.com/p/342261872
[2] https://zhuanlan.zhihu.com/p/340149804
