Vision Transformer (ViT) network architecture

Paper download link: https://arxiv.org/abs/2010.11929

Introduction

1. Comparison between ViT and traditional CNNs

There are several key differences between ViT (Vision Transformer) and traditional convolutional neural networks (CNNs) in image processing:

1. Model structure:

  • ViT: Built mainly on the Transformer architecture; no convolutional layers are used.
  • CNN: Uses convolutional, pooling, and fully connected layers.

2. Input processing:

  • ViT: Divides an image into multiple fixed-size patches and processes them all at once.
  • CNN: Gradually scans the entire image with convolutional windows.

3. Computational complexity:

  • ViT: Computational complexity may be higher due to the self-attention mechanism.
  • CNN: Usually easier to optimize, relatively low computational complexity.

4. Data dependencies:

  • ViT: Usually requires more data and computing resources for effective training.
  • CNN: Relatively easier to train on small datasets.

2. Why are Transformers needed for image tasks?

In the history of deep learning, Convolutional Neural Networks (CNNs) have long been the mainstream architecture for processing image tasks. However, with Transformer's successful application to natural language processing (NLP) tasks, researchers began to consider its potential in computer vision.

Flexible Global Attention Mechanism

  • Global context: Unlike a CNN with local receptive fields, the Transformer has a global receptive field, which allows it to fuse information across the entire image. This global context can be very useful in certain tasks, such as image segmentation, object detection, and multi-object interaction.

Interpretability and attention visualization

  • Better interpretability: Thanks to the self-attention mechanism, we can easily visualize the regions the model focuses on when making decisions, which increases the interpretability of the model.

Sequence-to-sequence tasks

  • Easier to handle sequence output: In tasks like image captioning, it becomes more straightforward to consider both image and text information, since both can be handled with a similar Transformer architecture.

Adaptability

  • Easier to adapt to different scales and shapes: Transformer does not rely on fixed-size filters, so it is theoretically easier to adapt to a wide variety of inputs.

1. The Transformer in depth

1.1 The origin of Transformer: a breakthrough in the field of NLP

The Transformer model was originally proposed by Google researchers in the 2017 paper "Attention Is All You Need". This model introduced a new architecture built mainly on the self-attention mechanism, and successfully tackled a series of natural language processing (NLP) tasks of the time. Here are some important breakthroughs and influences of the Transformer in the field of NLP:

1. A new perspective on sequence modeling
Traditional RNNs (Recurrent Neural Networks) and LSTMs (Long Short-Term Memory networks) suffer from vanishing or exploding gradients when processing long sequences because of their recursive nature. Through its self-attention mechanism, the Transformer successfully captures dependencies within a sequence and can process the entire sequence in parallel, surpassing RNNs and LSTMs in many respects.

2. Self-attention mechanism
The self-attention mechanism in the Transformer model allows the model to establish direct dependencies between inputs at different positions, which makes it easier for the model to understand the contextual relationships within a sentence or document. This mechanism is especially suitable for tasks that need to capture long-distance dependencies, such as machine translation, text summarization, and question answering.

3. Scalability
Thanks to its parallelism and relatively low time complexity, the Transformer architecture makes more efficient use of modern hardware. This allows researchers to train larger, more powerful models that achieve better performance.

4. Multi-modal and multi-task learning
The Transformer architecture is highly flexible and can easily be extended to other types of data and tasks, including image, audio, and multimodal inputs. This has been widely confirmed in subsequent research and applications.

5. Pre-training and fine-tuning
The Transformer architecture is suitable for pre-training and fine-tuning workflows. Large pre-trained models such as BERT, GPT, and T5 are built on Transformer and set new performance benchmarks on a variety of NLP tasks.

1.2 Basic composition of Transformer

1.2.1 Self-Attention Mechanism

A psychological perspective

  • Animals need to focus effectively on noteworthy points in complex environments.
  • Psychological framework: humans choose points of attention based on volitional and nonvolitional cues. (Note: "volitional" here does not mean "random" or "casual"; it refers to deliberate, active attention, while "nonvolitional" refers to unintentional, passive attention.)

Imagine we have five objects in front of us: a newspaper, a research paper, a cup of coffee, a notebook, and a book. All of the paper products are printed in black and white, while the coffee cup is red. In other words, the coffee cup is prominent and conspicuous in this visual environment and involuntarily attracts our attention, so we direct our sharpest gaze toward the coffee.

Wanting to read, on the other hand, is a volitional cue: attention shifts deliberately to the book.

Attention mechanism

  • In traditional CNN architectures, convolution, pooling, and fully connected layers only take nonvolitional cues into account.
  • The attention mechanism, by contrast, explicitly incorporates volitional cues:
    • The volitional cue is called the query.
    • Each input is a pair of a value and a nonvolitional cue, the key (the inputs here can be thought of as the environment).
    • The attention pooling layer biases the selection toward certain inputs according to the query, i.e., the volitional cue we have introduced.

Computation process

  1. Dot products: for a given query, compute the dot product with each key to measure the similarity between the query and that key.
  2. Scaling: scale the dot products, usually by dividing by the square root of the key vector dimension.
  3. Softmax: apply the softmax function so that the weights lie between 0 and 1 and sum to 1.
  4. Weighted sum: compute a weighted sum of the value vectors using the resulting weights.
  5. Output: optionally pass the weighted sum through a fully connected (linear) layer to produce the output at that position (a minimal code sketch follows).
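The steps above can be written compactly. The sketch below is a minimal scaled dot-product attention in PyTorch; the function name and the tensor shapes in the example are assumptions for illustration, and the optional final linear layer from step 5 is omitted.

import torch
import torch.nn.functional as F

def dot_product_attention(queries, keys, values):
    """Minimal scaled dot-product attention.
    queries: [batch, num_queries, d]; keys/values: [batch, num_kv, d]."""
    d = queries.shape[-1]
    # 1. Dot product between every query and every key
    scores = torch.bmm(queries, keys.transpose(1, 2))
    # 2. Scale by the square root of the key dimension
    scores = scores / (d ** 0.5)
    # 3. Softmax so the weights lie in (0, 1) and sum to 1
    weights = F.softmax(scores, dim=-1)
    # 4. Weighted sum of the values
    return torch.bmm(weights, values)

# Example: 2 samples, 4 queries, 6 key-value pairs, dimension 8
q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
v = torch.randn(2, 6, 8)
print(dot_product_attention(q, k, v).shape)  # torch.Size([2, 4, 8])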

Multi-Head Attention
To capture a richer set of dependencies, multi-head attention is usually used. In multi-head attention, the model maintains several independent sets of weight matrices for queries, keys, and values, and computes them in parallel. The outputs of the heads are concatenated and combined through a fully connected layer.
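For reference, PyTorch ships a built-in multi-head attention module; the short usage sketch below shows self-attention (query = key = value), with illustrative shapes.

import torch
import torch.nn as nn

embed_dim, num_heads = 64, 8
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(2, 10, embed_dim)     # [batch, seq_len, embed_dim]
out, attn_weights = mha(x, x, x)      # self-attention: Q = K = V = x
print(out.shape, attn_weights.shape)  # [2, 10, 64], [2, 10, 10]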

1.2.2 Feed-forward Neural Networks

Feed-forward neural networks (FFNNs) are among the earliest and simplest neural network architectures. The defining characteristic of such a network is that data propagates in only one direction: from the input layer, through the hidden layers, and finally to the output layer. This one-way flow of data is where the name "feedforward" comes from.

Structure and Components

  1. Input Layer: This layer receives raw input data and passes it on to the next layer.
  2. Hidden Layers: A network can contain one or more hidden layers, each consisting of multiple neurons. These layers capture the complex patterns of the input data.
  3. Output Layer: According to the requirements of the task (such as classification, regression, etc.), the output layer generates the final output of the network.

Activation functions
To introduce non-linearity, each neuron usually has an activation function. Commonly used activation functions are:

  • ReLU (Rectified Linear Unit)
  • Sigmoid
  • Tanh (Hyperbolic Tangent)
  • Leaky ReLU, Parametric ReLU, etc.

Training
Feedforward neural networks are usually trained using the Backpropagation algorithm, which involves:

  1. Forward Propagation: Starting from the input layer, data flows through the network to generate predicted outputs.
  2. Loss Calculation: Calculate the loss based on the predicted output and the actual target.
  3. Backward Propagation: Calculate the gradient of the loss with respect to each weight and update the weights in the network.

Application in Transformer
Although the Transformer architecture centers on the self-attention mechanism, a feed-forward neural network (usually a two-layer network) follows each attention module. This adds extra modeling capacity and helps the model capture different features of the data.
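A minimal sketch of this two-layer, position-wise feed-forward network is shown below. It is written in the d2l style so that it matches the PositionWiseFFN referenced in the encoder block later; the hidden size in the usage example is an assumption.

import torch
import torch.nn as nn

class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied independently at every position."""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

# The same weights are applied at every sequence position
ffn = PositionWiseFFN(512, 2048, 512)
print(ffn(torch.randn(2, 10, 512)).shape)  # torch.Size([2, 10, 512])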

1.2.3 Residual Connections

In the Transformer architecture, residual connections play a very critical role. They appear after the self-attention and feed-forward sublayers, and are usually used together with layer normalization.

Structure and function
In the Transformer, the output of each sublayer (such as multi-head self-attention or the feed-forward network) is added to the input of that sublayer, forming a residual connection. This structure can be expressed as:

Output = Sublayer(x) + x

or, more generally:

Output = LayerNorm(Sublayer(x) + x)

Here Sublayer(x) is the output of a sublayer (such as multi-head self-attention or the feed-forward network), and LayerNorm is layer normalization.

1.2.4 Layer Normalization

Rationale
The core idea of layer normalization is to normalize each sample in each layer independently, so that the output of each layer has roughly the same scale. Layer normalization is applied after fully connected or convolutional layers, but usually before activation functions.
Mathematically, for the feature vector x of a single sample, layer normalization computes

LN(x) = γ ⊙ (x − μ) / √(σ² + ε) + β

where μ and σ² are the mean and variance computed over the features of that sample, ε is a small constant for numerical stability, and γ and β are learnable scale and shift parameters.

Application in Transformer
In the Transformer architecture, layer normalization is usually used in conjunction with residual connections. Each residual connection is followed by a layer normalization step to stabilize model training. This combination helps the model maintain numerical stability during training, especially for very deep models.
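The AddNorm module below implements this residual-plus-LayerNorm pattern: dropout is applied to the sublayer output Y before it is added to the sublayer input X and normalized.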

class AddNorm(nn.Module):
    """Layer normalization applied after a residual connection."""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Y is the sublayer output, X is the sublayer input (residual branch)
        return self.ln(self.dropout(Y) + X)

Advantages

  1. Numerical Stability: Layer normalization helps prevent the vanishing or exploding gradient problem, making the model easier to train.
  2. Accelerated Convergence: By adjusting the scale of each layer, layer normalization can speed up the convergence of the model.
  3. Adaptability: Layer normalization is applicable to network architectures of different types and depths, including recurrent neural networks (RNNs).

Disadvantages

  1. Task dependence: in some settings, for example many convolutional vision models, layer normalization can be less effective than batch normalization.
  2. Model Complexity: Additional learnable parameters are introduced, which may increase the complexity of the model.

2. From CNN to Vision Transformer

Both Convolutional Neural Networks (CNN) and Vision Transformer (ViT) are popular models for image processing tasks, but they have different design philosophies and application scopes. The evolution between the two is briefly described below.

2.1 Limitations of CNNs

1. Local receptive fields
A CNN processes images through local receptive fields, which is a limitation in some tasks. While this design helps identify local structures in images, it is less suited to capturing long-distance dependencies.

2. Computational cost
When processing high-resolution images, the computational cost of convolution operations can be very high.

3. Spatial structure assumption
CNNs assume that the input data has an inherent spatial or temporal structure, which makes them harder to apply to data without a clear spatial structure.

4. Parameter efficiency
In terms of parameter efficiency, even with techniques such as batch normalization and residual connections, a CNN may still fall short of a Transformer model.

2.2 Emergence and motivation of Vision Transformer

Vision Transformer was first proposed by Google Research in 2020, and its design was inspired by the Transformer model for natural language processing.

1. Global attention
Unlike a CNN, ViT uses a global self-attention mechanism, which is better at handling long-distance dependencies.

2. Computational efficiency
ViT relies on self-attention and feed-forward networks to achieve computational efficiency, especially when dealing with high-resolution images.

3. Modularity and scalability
ViT has good modularity and scalability, and can easily adjust the size and complexity of the model.

4. Parameter efficiency
After pre-training on large datasets, ViT often exhibits high parameter efficiency, i.e., it can perform better than a CNN with the same number of parameters.

5. Cross-modal applications
Since ViT does not hard-code spatial assumptions, it is also easier to apply to other types of data and tasks.

3. How the Vision Transformer works

3.1 Input: Split the image into patches


  1. Image segmentation: Vision Transformer (ViT) first divides the input image into multiple fixed-size patches. These patches are usually square, for example 16x16 pixels.
  2. Flattening: each patch is flattened into a 1D vector.
  3. Merging: all of these 1D vectors are then concatenated into a sequence, which serves as the input to the Transformer encoder.

3.2 Embedding: linear embedding and position embedding

  1. Linear embedding: the patches are embedded through a linear layer (usually a fully connected layer) that transforms them into vectors of a suitable dimension. This is roughly equivalent to feature extraction with a very shallow CNN layer.
  2. Positional embedding: because the original position information is lost when the patches are flattened, position embeddings must be added to help the model identify the relative or absolute positions of the patches.
  3. Merging: the linear embeddings and positional embeddings are usually added together to produce an embedding sequence that includes positional information (a shape walkthrough follows this list).
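The sketch below walks through the shapes in steps 3.1 and 3.2. The 224x224 input, 16x16 patches, and embedding dimension 768 follow the common ViT-Base convention and are assumptions for illustration; a real model would make the positional embedding part of a module.

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224
patch_size, embed_dim = 16, 768
num_patches = (H // patch_size) * (W // patch_size)   # 14 * 14 = 196

img = torch.randn(B, C, H, W)

# 1. Split into patches and flatten each one: [B, 196, 16*16*3]
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, num_patches, -1)

# 2. Linearly embed every flattened patch: [B, 196, 768]
linear_embed = nn.Linear(C * patch_size * patch_size, embed_dim)
tokens = linear_embed(patches)

# 3. Add learnable positional embeddings (one per patch)
pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
tokens = tokens + pos_embed
print(tokens.shape)  # torch.Size([2, 196, 768])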

3.3 Transformer encoder

  1. Self-attention layer: This layer uses a self-attention mechanism to analyze each element in the input sequence (that is, each patch and its corresponding position embedding) in order to better represent the relationship between the various patches.
  2. Feed-forward Neural Network: The output of the self-attention layer is fed into a Feed-forward Neural Network.
  3. Residual connection and layer normalization: After the self-attention layer and feed-forward neural network, there will be residual connection and layer normalization operations to promote the stability and efficiency of model training.
  4. Stacked encoders: all of the above components are stacked multiple times (e.g., 12 or 24 times) to form the complete Transformer encoder.
  5. Classification Head: For classification tasks, it is common to take the first element of the encoder output sequence (usually corresponding to a special "[CLS]" token) and pass it through a fully connected layer for classification.
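The d2l-style encoder block below assembles these components; it assumes "from d2l import torch as d2l" for MultiHeadAttention, together with the AddNorm and PositionWiseFFN modules shown earlier.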
class EncoderBlock(nn.Module):
    """Transformer encoder block."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        # Self-attention (Q = K = V = X), add & norm, then FFN, add & norm
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

No layer in the Transformer encoder changes the shape of its input.

3.4 Output Head: Classification Task

In the Vision Transformer (ViT) model, the output head for classification tasks is usually a fully connected (linear) layer that maps the output of the Transformer encoder to the class logits. In most implementations, the feature at the first position of the encoder output, corresponding to the special [CLS] token that was prepended, is used as the input to this head.
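A minimal sketch of such a head, assuming the [CLS] token was prepended at position 0 and using illustrative dimensions:

import torch
import torch.nn as nn

embed_dim, num_classes = 768, 1000            # illustrative values
head = nn.Linear(embed_dim, num_classes)

encoder_out = torch.randn(2, 197, embed_dim)  # [batch, 1 + num_patches, embed_dim]
cls_feature = encoder_out[:, 0]               # feature at the [CLS] position
logits = head(cls_feature)                    # [2, 1000]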

4. Variants of ViT and related work

With the success of the Vision Transformer (ViT) in image classification tasks, many researchers began to explore variants and improvements. Here is an overview of some noteworthy variants and related work:

4.1 DeiT (Data-efficient Image Transformer)

4.1.1 Overview

  • Concept: DeiT focuses on using data more effectively. A standard ViT requires a lot of data and computing resources for pre-training; DeiT adopts a more efficient training strategy, in particular data augmentation and knowledge distillation, to improve on this.
  • Main features: uses knowledge distillation and training techniques such as learning-rate scheduling and data augmentation to reduce the dependence on large amounts of labeled data.
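The sketch below follows the plain ViT recipe and omits DeiT's extra distillation token; the 224x224 input assumption and the ViT-Base-style hyperparameters at the end are for illustration only.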
import torch
import torch.nn as nn
import torch.nn.functional as F

# Split the image into patches and embed them
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H/patch_size, W/patch_size]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        return x

# DeiT model body (simplified: no distillation token)
class DeiT(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()

        # Split the image into patches and embed them
        self.patch_embed = PatchEmbedding(patch_size, in_channels, embed_dim)

        # Special [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embedding (assumes 224x224 inputs)
        num_patches = (224 // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Transformer encoder (batch_first so inputs are [B, seq, dim])
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Classification head
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)

        # Patchify and embed
        x = self.patch_embed(x)

        # Prepend the [CLS] token
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed

        # Run through the Transformer encoder
        x = self.transformer(x)

        # Take only the output at the [CLS] position for classification
        x = x[:, 0]

        # Classification head
        x = self.fc(x)

        return x

# Hyperparameters (ViT-Base-like, for illustration)
patch_size = 16
in_channels = 3
embed_dim = 768
num_heads = 12
num_layers = 12
num_classes = 1000  # assume a 1000-class problem

# Instantiate the model
model = DeiT(patch_size, in_channels, embed_dim, num_heads, num_layers, num_classes)

# Dummy data: 32 RGB images of size 224x224
x = torch.randn(32, 3, 224, 224)

# Forward pass
logits = model(x)

4.1.2 Knowledge Distillation

Knowledge distillation (KD) is a model compression technique for transferring knowledge from a large, complex model (often referred to as the "teacher model") to a smaller, simpler model (often referred to as the "student model"). The goal is to reduce model size and inference time while maintaining performance close to that of the larger model.

How it works

  • Teacher model: usually a pre-trained large model, used to generate soft labels, i.e., probability distributions over the classes.
  • Student model: usually a smaller model that is trained to imitate the teacher model.
  • Distillation loss: in the most basic form of knowledge distillation, the student is trained to minimize not only the loss against the true labels (such as cross-entropy) but also the loss against the soft labels predicted by the teacher (see the formula below).
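A common formulation of the combined objective is L = (1 − α) · CE(s, y) + α · T² · KL(softmax(t / T) ‖ softmax(s / T)), where s and t are the student and teacher logits, y is the ground-truth label, T is the temperature, and α is the soft-label weight.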

Simple knowledge distillation code example
Suppose we have a teacher model (teacher_model) and a student model (student_model), the following is a simple example of knowledge distillation using PyTorch:

import torch
import torch.nn.functional as F

# Assume teacher_model and student_model are already defined and initialized
# teacher_model = ...
# student_model = ...

# Data loader
# data_loader = ...

# Optimizer
optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

# Temperature and soft-label weight
temperature = 2.0
alpha = 0.9

# Training loop
for data, labels in data_loader:
    optimizer.zero_grad()

    # Forward pass: teacher and student models
    teacher_output = teacher_model(data).detach()  # note: gradients are usually not computed for the teacher
    student_output = student_model(data)

    # Compute the losses
    hard_loss = F.cross_entropy(student_output, labels)  # loss against the true labels
    soft_loss = F.kl_div(F.log_softmax(student_output / temperature, dim=1),
                         F.softmax(teacher_output / temperature, dim=1),
                         reduction='batchmean')  # loss against the soft labels (often also scaled by T^2)

    loss = alpha * soft_loss + (1 - alpha) * hard_loss

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

Application scenarios
Knowledge distillation is not only suitable for model compression; it can also be used to improve the performance of small models in specific applications, for example to improve data efficiency in DeiT (Data-efficient Image Transformer).

4.1.3 Transformer model optimized by knowledge distillation

Below we assume that there is a large Transformer model (teacher model) that has been trained, and a smaller Transformer model (student model).

Note: for simplicity, we use PyTorch's built-in Transformer encoder modules as a simple Transformer implementation. You can replace them with more complex models as needed.

The loss function consists of two parts: one is the loss between the student model and the actual labels, and the other is the Kullback-Leibler divergence between the student and teacher model outputs.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define a simple Transformer classifier
class SimpleTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes):
        super(SimpleTransformer, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.encoder(x)   # [N, seq_len, d_model]
        x = x.mean(dim=1)     # average over the sequence
        x = self.classifier(x)
        return x

# Define the distillation loss
def distillation_loss(y, labels, teacher_output, T=2.0, alpha=0.5):
    hard_loss = nn.CrossEntropyLoss()(y, labels)
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y / T, dim=1),
                                                    F.softmax(teacher_output / T, dim=1))
    return (1. - alpha) * hard_loss + (alpha * T * T) * soft_loss

# Assume we have some data
# Note: random data is used here purely as an example
N = 100  # number of data points
d_model = 32  # embedding dimension
nhead = 2  # number of attention heads
num_layers = 2  # number of Transformer layers
num_classes = 10  # number of classes
T = 2.0  # temperature
alpha = 0.5  # weight of the distillation loss

x = torch.randn(N, 10, d_model)
labels = torch.randint(0, num_classes, (N,))

# Initialize the teacher and student models
teacher_model = SimpleTransformer(d_model, nhead, num_layers, num_classes)
student_model = SimpleTransformer(d_model, nhead, num_layers, num_classes)

# Set up the optimizer
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# Simulated training loop
for epoch in range(10):
    # Forward pass
    teacher_output = teacher_model(x).detach()  # in practice the teacher is pre-trained, so no gradients are needed
    student_output = student_model(x)

    # Compute the loss
    loss = distillation_loss(student_output, labels, teacher_output, T, alpha)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

4.2 Hybrid models (ViT + CNN)

Hybrid models combine the advantages of Vision Transformer (ViT) and Convolutional Neural Network (CNN) to achieve more powerful image recognition capabilities. Such models usually use a CNN as a feature extractor whose output is used as input to ViT.

4.2.1 Why Mixed Models?

  1. Local and global features: CNNs are very good at capturing local features, while the Transformer can handle global dependencies. Combining the two allows a more complete understanding of the image.
  2. Computational efficiency: CNNs are generally more efficient at processing raw image data. Using a CNN at the front end of the model can reduce the computational burden on the Transformer.
  3. Data efficiency: using pre-trained CNN features can improve the data efficiency of the model, which is especially useful for tasks with little training data.

4.2.2 Infrastructure

In a typical hybrid model, the CNN serves as the feature extractor, while the ViT handles feature encoding and classification.

  1. Feature extraction: use a CNN (possibly a pretrained network such as ResNet or VGG) to extract features from the input image.
  2. Patching and embedding: split the CNN output into patches and convert them into a sequence suitable for the Transformer via a linear embedding layer (or other methods).
  3. Transformer encoding: further encoding of features using ViT.
  4. Classification Head: Finally, a fully connected layer is used for classification.

4.2.3 Examples

import torch
import torch.nn as nn

# Schematic skeleton; the '...' placeholders mark omitted architecture details

# Assume some version of ResNet is used as the feature extractor
class FeatureExtractor(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # Define the CNN structure, e.g. a simplified ResNet
        ...

    def forward(self, x):
        # Extract features with the CNN
        return x

# ViT as the encoder
class ViTEncoder(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # Define the Transformer structure
        ...

    def forward(self, x):
        # Encode the features with the Transformer
        return x

# Hybrid model
class HybridModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.feature_extractor = FeatureExtractor(...)
        self.vit_encoder = ViTEncoder(...)
        self.classifier = nn.Linear(...)

    def forward(self, x):
        x = self.feature_extractor(x)  # CNN feature extraction
        x = self.vit_encoder(x)  # Transformer encoding
        x = self.classifier(x)  # classification head
        return x
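A sketch of one possible concrete instantiation of this skeleton, assuming a torchvision ResNet-18 backbone (resnet18(weights=None) requires a recent torchvision; older versions use pretrained=False), PyTorch's built-in Transformer encoder, and illustrative hyperparameters; the class name is hypothetical.

import torch
import torch.nn as nn
from torchvision.models import resnet18

class HybridResNetViT(nn.Module):
    def __init__(self, embed_dim=256, num_heads=4, num_layers=2, num_classes=10):
        super().__init__()
        backbone = resnet18(weights=None)
        # Keep everything up to the last conv stage; output is [B, 512, H/32, W/32]
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-2])
        self.proj = nn.Conv2d(512, embed_dim, kernel_size=1)  # map CNN channels to token dim
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)         # CNN features: [B, 512, 7, 7] for 224x224 input
        x = self.proj(x)                      # [B, embed_dim, 7, 7]
        x = x.flatten(2).transpose(1, 2)      # token sequence: [B, 49, embed_dim]
        x = self.encoder(x)                   # Transformer encoding
        return self.classifier(x.mean(dim=1)) # mean-pool tokens, then classify

model = HybridResNetViT()
print(model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 10])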

4.3 Swin Transformer

Swin Transformer is a Transformer architecture for computer vision tasks that introduces a self-attention mechanism based on shifted local windows. This approach combines the advantages of convolutional neural networks (CNNs) and Transformers, aiming to achieve higher model efficiency and performance.

4.3.1 Main Features

  1. Hierarchical feature extraction: like a CNN, the Swin Transformer performs multi-level feature extraction, downsampling at each stage, but here it is implemented with Transformer blocks.
  2. Window-based self-attention: the Swin Transformer uses a (shifted-)window self-attention mechanism that attends only to local context within each window, instead of the global context of the standard Transformer. This reduces the computational complexity.
  3. Patch merging: across stages, the Swin Transformer merges neighboring patches, progressively reducing the sequence length while increasing the feature dimension, to extract higher-level features.
  4. Flexibility: Swin Transformer can be used for a variety of computer vision tasks, such as image classification, object detection, and semantic segmentation.

4.3.2 Infrastructure

  1. Patch embedding: divide the image into multiple small patches and embed them with a linear embedding layer.
  2. Swin Transformer blocks: consist of multiple Swin Transformer layers, each with one or more window-based self-attention mechanisms and feed-forward neural networks.
  3. Head: depending on the task (classification, detection, etc.), different head structures are added after the last layer of the Swin Transformer.

4.3.3 Code example

  • PatchEmbedding: This part is responsible for cutting the input image into small pieces and embedding them.
  • WindowAttention: a simplified stand-in for Swin's window attention (no window shifting or relative position bias), used for self-attention calculations within local windows.
  • SwinBlock: Consists of a window attention layer and a multi-layer perceptron (MLP).
  • SwinTransformer: The final model architecture.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Split the image into patches
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, out_dim, patch_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.conv(x)                  # [B, out_dim, H/patch, W/patch]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, out_dim]
        return x

# Window attention (simplified: plain non-overlapping windows,
# no window shifting and no relative position bias)
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, window_size):
        super().__init__()
        assert dim % heads == 0
        self.dim = dim
        self.heads = heads
        self.window_size = window_size  # window side length; each window holds window_size**2 tokens

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        # x: [B, num_patches, dim]; num_patches must be divisible by window_size**2
        B, N, D = x.shape
        W = self.window_size ** 2   # tokens per window
        H = self.heads
        head_dim = D // H

        # Partition the sequence into non-overlapping windows: [B * num_windows, W, D]
        windows = x.view(B * (N // W), W, D)

        # Compute q, k, v and split into heads: [B * num_windows, H, W, head_dim]
        q = self.query(windows).view(-1, W, H, head_dim).transpose(1, 2)
        k = self.key(windows).view(-1, W, H, head_dim).transpose(1, 2)
        v = self.value(windows).view(-1, W, H, head_dim).transpose(1, 2)

        # Scaled dot-product attention within each window
        attn = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        out = attn @ v  # [B * num_windows, H, W, head_dim]

        # Merge heads and windows back into the original sequence shape
        out = out.transpose(1, 2).reshape(B * (N // W), W, D)
        return out.view(B, N, D)

# Swin Transformer block
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, window_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, heads, window_size)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Swin Transformer model
class SwinTransformer(nn.Module):
    def __init__(self, in_channels, out_dim, patch_size, num_classes):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, out_dim, patch_size)

        # Assume 4 Swin blocks with 8 heads and 8x8 windows (64 tokens each)
        self.blocks = nn.ModuleList([
            SwinBlock(out_dim, 8, 8) for _ in range(4)
        ])

        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for block in self.blocks:
            x = block(x)
        # Global average pooling over the patch dimension: [B, out_dim]
        x = self.global_avg_pool(x.transpose(1, 2)).squeeze(-1)
        x = self.fc(x)
        return x

# Test the model
if __name__ == '__main__':
    model = SwinTransformer(3, 128, 4, 10)
    x = torch.randn(16, 3, 32, 32)  # assume 16 images of size 32x32
    y = model(x)
    print(y.shape)  # should print torch.Size([16, 10])

5. Advantages and disadvantages of ViT

5.1 Advantages compared with CNN

  1. Better handling of long-distance dependencies: the Transformer architecture was originally designed to capture long-distance dependencies, which is very useful in some complex image recognition tasks.
  2. Parameter efficiency: ViT can potentially achieve the same performance as a CNN with fewer parameters.
  3. Interpretability: the outputs of the self-attention mechanism can be used to analyze how much attention the model pays to each part of the image, which helps with model interpretation.
  4. Flexibility and generalization: the Transformer does not depend on fixed-size filters or local regions, so it has the potential to generalize better to different types and structures of visual data.
  5. End-to-end training: compared with some specially designed CNN architectures, ViT can be trained end to end with a unified architecture.

5.2 Challenges and limitations of ViT

  1. Computational complexity: for large images, the computational complexity of the global self-attention mechanism can be very high. This is one of the reasons the Transformer was initially used mainly in NLP.
  2. Data dependence: ViT usually requires a large amount of labeled data for effective training, which can be a problem in scenarios without much labeled data.
  3. Unstable training: Transformer architectures are usually harder to train than CNNs, especially without sufficient computing resources and data.
  4. Weaker local feature processing: lacking built-in convolution operations, ViT may underperform a specially designed CNN on tasks that rely on local features (such as texture recognition).
  5. Memory consumption: especially on large images or long sequences, Transformer models (including ViT) often need more memory.
  6. Risk of overfitting: due to the usually large model complexity and parameter count, ViT is more prone to overfitting, especially when the amount of data is small.

Origin blog.csdn.net/m0_63260018/article/details/130999047