PyTorch: implementing various attention mechanisms

1. What is Attention?

The Attention mechanism is a mechanism that focuses on local information, such as a particular region of an image. The regions that deserve attention tend to change as the task changes.

Faced with a picture like the one above, if you only look at it as a whole, you just see a lot of heads. But if you zoom in and look at each one closely, it becomes remarkable: they are all genius scientists.

Apart from the faces, the information in the picture is essentially useless for this task. The Attention mechanism is about finding the most useful information; the simplest scenario to imagine here is detecting the faces in the photo.

The core idea of the attention mechanism is to let the network concentrate on the places that need more attention.

When we use a convolutional neural network to process images, we want it to attend to the places that deserve attention rather than to everything equally. Since we cannot manually specify where it should look, it becomes essential for the convolutional neural network to attend to important objects adaptively.

The attention mechanism is one way to make a network attend adaptively.

Generally speaking, attention mechanisms can be divided into channel attention mechanisms, spatial attention mechanisms, and a combination of the two.

2. How to implement the attention mechanism

2.1 Implementation of SENet

SENet is a typical implementation of the channel attention mechanism. Its specific implementation is:

1. Perform global average pooling on the input feature layer.

2. Then apply two fully connected layers. The first has fewer neurons than the input has channels, and the second has as many neurons as the input feature layer has channels.

3. After the two fully connected layers, apply a Sigmoid to squash the values into the range 0 to 1. This gives a weight (between 0 and 1) for each channel of the input feature layer.

4. Multiply these weights with the original input feature layer, channel by channel.
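The four steps above can be traced with plain tensor operations. This is a minimal sketch, assuming a batch of 2 feature maps with 64 channels and a reduction ratio of 16 (the same default `ratio` used by `se_block` below). The two `nn.Linear` layers are freshly initialized, so the resulting weights are arbitrary. Only the shapes and the data flow matter here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, c, h, w, ratio = 2, 64, 8, 8, 16
x = torch.randn(b, c, h, w)  # input feature layer [B, C, H, W]

# Step 1: global average pooling -> one scalar per channel, [B, C]
squeezed = x.mean(dim=(2, 3))

# Step 2: two fully connected layers, C -> C/ratio -> C
fc1 = nn.Linear(c, c // ratio, bias=False)
fc2 = nn.Linear(c // ratio, c, bias=False)
hidden = torch.relu(fc1(squeezed))  # [B, C/ratio]
logits = fc2(hidden)                # [B, C]

# Step 3: sigmoid squashes each channel weight into the range 0..1
weights = torch.sigmoid(logits)

# Step 4: rescale every channel of the input by its own weight
out = x * weights.view(b, c, 1, 1)
print(squeezed.shape, hidden.shape, weights.shape, out.shape)
```

The output has the same shape as the input. Each channel has simply been scaled by a learned factor.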


The implementation code is as follows:

import torch
import torch.nn as nn
import math

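class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        # Squeeze: global average pooling to [B, C, 1, 1]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: C -> C/ratio -> C, then Sigmoid to get per-channel weights in (0, 1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # [B, C]
        y = self.fc(y).view(b, c, 1, 1)  # per-channel weights, [B, C, 1, 1]
        return x * y                     # rescale each channel of the input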

2.2 CBAM implementation

CBAM combines the channel attention mechanism with the spatial attention mechanism. SENet attends only over channels, and CBAM is reported to achieve better results than it. The schematic diagram of its implementation is shown below. CBAM applies the channel attention mechanism and then the spatial attention mechanism to the input feature layer.


The following figure shows the specific implementation of the channel attention mechanism and spatial attention mechanism:

The upper part of the figure is the channel attention mechanism, which can be split into two parts. First, global average pooling and global max pooling are applied separately to the input feature layer. The two pooled results are then processed by a shared fully connected layer. The two processed results are added and passed through a sigmoid. This gives a weight (between 0 and 1) for each channel of the input feature layer. Finally, these weights are multiplied with the original input feature layer.
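The channel branch just described can be sketched with plain tensor operations. The sketch assumes 32 channels and a reduction ratio of 8. `shared_mlp` is a name introduced here for illustration. The key point is that one and the same MLP is applied to both the average-pooled and the max-pooled descriptors before they are summed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 32, 16, 16)  # [B, C, H, W]

# One MLP shared by both pooled descriptors (C -> C/8 -> C)
shared_mlp = nn.Sequential(
    nn.Linear(32, 32 // 8, bias=False),
    nn.ReLU(),
    nn.Linear(32 // 8, 32, bias=False),
)

avg_desc = x.mean(dim=(2, 3))  # global average pooling, [B, C]
max_desc = x.amax(dim=(2, 3))  # global max pooling, [B, C]

# Add the two MLP outputs, then sigmoid -> one weight per channel
channel_weights = torch.sigmoid(shared_mlp(avg_desc) + shared_mlp(max_desc))
out = x * channel_weights[:, :, None, None]  # broadcast over H and W
```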

The lower part of the figure is the spatial attention mechanism. For every feature point of the input feature layer, we take the maximum and the average over the channel dimension. The two results are stacked, a convolution with a single output channel reduces the number of channels to 1, and a sigmoid is applied. This gives a weight (between 0 and 1) for each feature point of the input feature layer, which is then multiplied with the original input feature layer.
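The spatial branch can be sketched the same way. The sketch assumes a 7x7 kernel, as in the default `SpatialAttention` below. The weight map has one value per pixel and is shared by all channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 32, 16, 16)  # [B, C, H, W]

# Per-pixel statistics across the channel axis, each [B, 1, H, W]
avg_map = x.mean(dim=1, keepdim=True)
max_map = x.amax(dim=1, keepdim=True)

# Stack to 2 channels; a 7x7 conv with padding 3 maps them to 1 channel
# while keeping H and W unchanged
conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
spatial_weights = torch.sigmoid(conv(torch.cat([avg_map, max_map], dim=1)))

out = x * spatial_weights  # [B, 1, H, W] broadcasts over all channels
```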


Implementation:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Use 1x1 convolutions instead of fully connected layers
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x

2.3 Implementation of ECA

ECANet is another implementation of the channel attention mechanism and can be regarded as an improved version of SENet.
The authors of ECANet argue that the dimensionality reduction SENet uses when predicting channel attention has side effects. They also argue that capturing the dependencies between all channels is inefficient and unnecessary.
The ECANet paper further argues that convolution is good at capturing cross-channel information.

The idea of the ECA module is very simple. It removes the fully connected layers of the original SE module and instead learns the channel weights with a 1D convolution applied directly to the globally average-pooled features.

A 1D convolution slides a kernel along the input sequence and produces one output value at each position. As a simple example, suppose we have a 1D tensor $x$ of length 10, a kernel of size 3, stride 1, and "VALID" padding, i.e. no padding at all. The kernel weights are:

$$W = \begin{bmatrix}1 & -1 & 0.5\end{bmatrix}$$

Then, we can perform a 1D convolution operation on the input tensor in the following way:

  1. Slide the convolution kernel from left to right, moving one position at a time, to perform a convolution operation with a part of the input tensor.

  2. Store the result of the convolution in the corresponding location of the output tensor.

  3. Repeat steps 1 and 2 until the kernel slides to the end of the input tensor.

Specifically, we can use the following method to calculate each element in the output tensor:

$$y_i = \sum_{j=0}^{2} W_j x_{i+j}$$

where $y_i$ is the $i$-th element of the output tensor, $W_j$ is the $j$-th weight of the kernel, and $x_{i+j}$ is the $(i+j)$-th element of the input tensor. With "VALID" padding, the kernel is only placed where it fits entirely inside the input. No padding is added at the edges, so the output has $10 - 3 + 1 = 8$ elements.

Here is a simple Python code example that demonstrates how to use PyTorch to implement a 1D convolution operation:

import torch
import torch.nn as nn

# Define the input tensor
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
# The two unsqueeze(0) calls give the Conv1d input layout: [batch_size, in_channels, sequence_length]

# Define the convolution kernel
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0, bias=False)
conv.weight.data = torch.tensor([[[1, -1, 0.5]]], dtype=torch.float32)

# Run the 1D convolution
y = conv(x)

# Print the result
print(y)

The running results are as follows:

tensor([[[0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000]]],
       grad_fn=<ConvolutionBackward0>)

Each element of the output is the dot product of the kernel with a length-3 window of the input. For example, $y_0 = 1 \times 1 + (-1) \times 2 + 0.5 \times 3 = 0.5$. This is a simple example of 1D convolution, which can be applied to time series data, text data, and so on.

As shown in the picture below, the left picture is the conventional SE module and the right picture is the ECA module. The ECA module replaces two full connections with 1D convolution.

Original link: https://blog.csdn.net/weixin_44791964/article/details/121371986
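Counting the learnable parameters of each excitation step makes the replacement concrete. The sketch assumes 256 channels, an SE reduction ratio of 16, and an ECA kernel size of 5, which is what the adaptive rule in the code below gives for 256 channels. Biases are disabled in both, matching the code in this post. The count says nothing about accuracy. It only shows how much lighter the 1D convolution is.

```python
import torch.nn as nn

channels, ratio, k = 256, 16, 5

# SE excitation: two fully connected layers C -> C/r -> C
se_fc = nn.Sequential(
    nn.Linear(channels, channels // ratio, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(channels // ratio, channels, bias=False),
)
# ECA excitation: a single 1D convolution sliding over the channel axis
eca_conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

print(n_params(se_fc), n_params(eca_conv))  # 8192 5
```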


Specific code implementation:

import math
import torch
import torch.nn as nn

class eca_block(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        # Adaptive kernel size: k = |(log2(C) + b) / gamma|, forced to be odd
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # "same" padding keeps the length C of the channel sequence unchanged
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)  # [B, C, 1, 1]
        # [B, C, 1, 1] -> [B, 1, C]: treat channels as a 1D sequence, convolve, restore [B, C, 1, 1]
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)
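`eca_block` does not use a fixed kernel size. It derives an odd `kernel_size` from the channel count $C$ via $k = \left|\frac{\log_2 C + b}{\gamma}\right|$, truncated to an integer and bumped to the next odd number if it is even. The sketch below repeats that rule for a few channel counts. `eca_kernel_size` is a helper name introduced here. It calls `math.log2` (an assumption made to avoid floating-point rounding in the logarithm), whereas `eca_block` calls `math.log(channel, 2)`.

```python
import math

def eca_kernel_size(channels, b=1, gamma=2):
    # k = |(log2(C) + b) / gamma|, truncated to int as in eca_block
    k = int(abs((math.log2(channels) + b) / gamma))
    # an odd size lets padding (k - 1) // 2 keep the sequence length C
    return k if k % 2 else k + 1

for c in (16, 64, 128, 256, 512, 2048):
    print(c, eca_kernel_size(c))
```

Wider layers thus get a wider 1D kernel, so each channel weight is computed from more neighboring channels.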

3. Self-attention implementation

3.1 Understanding QKV in Self-Attention

3.2 The understanding of self in Self-attention

3.3 Calculation process in Self-attention

Original link: https://blog.csdn.net/qq_37541097/article/details/117691873

Assume that the input sequence has length 2, with two nodes $x_1$ and $x_2$. The inputs are first mapped to $a_1, a_2$ by the Input Embedding, i.e. $f(x)$ in the figure. Then $a_1$ and $a_2$ are each multiplied by three transformation matrices $W_q, W_k, W_v$ (these three parameters are trainable and shared) to obtain the corresponding $q^i, k^i, v^i$.
(In the source code this is implemented directly with fully connected layers. The bias is ignored here for ease of understanding.)

  • q stands for query; it will later be matched against every k;

  • k stands for key; it will later be matched by every q;

  • v represents the information extracted from a.

  • The subsequent matching of q and k can be understood as computing the correlation between them: the greater the correlation, the greater the weight of the corresponding v.

Suppose $a_1 = (1, 1)$, $a_2 = (1, 0)$ and $W^q = \begin{pmatrix} 1 & 1 \\ 0 & 1 \end{pmatrix}$. Then:

$$q^1 = (1, 1) \begin{pmatrix} 1 & 1 \\ 0 & 1 \end{pmatrix} = (1, 2)$$

$$q^2 = (1, 0) \begin{pmatrix} 1 & 1 \\ 0 & 1 \end{pmatrix} = (1, 1)$$

As mentioned earlier, the Transformer can be parallelized, so this can be written directly in matrix form:

$$\begin{pmatrix}q^1\\q^2\end{pmatrix} = \begin{pmatrix}1 & 1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix}1 & 1 \\ 0 & 1 \end{pmatrix} = \begin{pmatrix}1 & 2 \\ 1 & 1 \end{pmatrix}$$

Similarly we obtain $\begin{pmatrix}k^1\\k^2\end{pmatrix}$ and $\begin{pmatrix}v^1\\v^2\end{pmatrix}$. The stacked $\begin{pmatrix}q^1\\q^2\end{pmatrix}$ is $Q$ in the original paper, $\begin{pmatrix}k^1\\k^2\end{pmatrix}$ is $K$, and $\begin{pmatrix}v^1\\v^2\end{pmatrix}$ is $V$.

Next, take $q^1$ and match it against every $k$. Compute the dot product and divide by $\sqrt{d}$ to get the corresponding $\alpha$, where $d$ is the length of the vector $k^i$ (2 in this example). The paper explains the division by $\sqrt{d}$ as follows. The dot products can become very large, which pushes the softmax into regions where its gradients are very small, so the scores are scaled down by $\sqrt{d}$.

For example, computing $\alpha_{1,i}$:

$$\alpha_{1,1} = \frac{q^1 \cdot k^1}{\sqrt{d}} = \frac{1 \times 1 + 2 \times 0}{\sqrt{2}} = 0.71$$

$$\alpha_{1,2} = \frac{q^1 \cdot k^2}{\sqrt{d}} = \frac{1 \times 0 + 2 \times 1}{\sqrt{2}} = 1.41$$

Similarly, matching $q^2$ against every $k$ gives $\alpha_{2,i}$. Written uniformly as a matrix multiplication:

$$\begin{pmatrix} \alpha_{1,1} & \alpha_{1,2} \\ \alpha_{2,1} & \alpha_{2,2}\end{pmatrix} = \frac{\begin{pmatrix}q^1 \\ q^2\end{pmatrix} \begin{pmatrix}(k^1)^T & (k^2)^T\end{pmatrix}}{\sqrt{d}} = \frac{QK^T}{\sqrt{d}}$$

Then apply softmax to each row separately, i.e. to $(\alpha_{1,1}, \alpha_{1,2})$ and to $(\alpha_{2,1}, \alpha_{2,2})$. This gives $(\hat{\alpha}_{1,1}, \hat{\alpha}_{1,2})$ and $(\hat{\alpha}_{2,1}, \hat{\alpha}_{2,2})$. These $\hat{\alpha}$ are the weights assigned to each $v$. This completes the $\mathrm{softmax}(QK^T/\sqrt{d_k})$ part of the $\mathrm{Attention}(Q, K, V)$ formula.

The $\hat{\alpha}$ computed above are the weights for each $v$. Taking the weighted sum gives the final result:

$$b_1 = \hat{\alpha}_{1,1} \times v^1 + \hat{\alpha}_{1,2} \times v^2 = (0.33, 0.67), \quad b_2 = \hat{\alpha}_{2,1} \times v^1 + \hat{\alpha}_{2,2} \times v^2 = (0.50, 0.50)$$

Written in matrix form:

$$\begin{pmatrix} b_1 \\ b_2 \end{pmatrix} = \begin{pmatrix} \hat{\alpha}_{1,1} & \hat{\alpha}_{1,2} \\ \hat{\alpha}_{2,1} & \hat{\alpha}_{2,2} \end{pmatrix} \begin{pmatrix} v^1 \\ v^2 \end{pmatrix}$$
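The toy example can be reproduced numerically. The text does not list $K$ and $V$ explicitly, so the sketch makes two assumptions. First, $k^1 = (1, 0)$ and $k^2 = (0, 1)$, the values implied by the $\alpha_{1,i}$ calculation. Second, $v^1 = (1, 0)$ and $v^2 = (0, 1)$, chosen because they reproduce $b_1 = (0.33, 0.67)$ and $b_2 = (0.50, 0.50)$.

```python
import torch

# Embedded inputs a_1, a_2 stacked as rows, and W^q from the text
A = torch.tensor([[1., 1.], [1., 0.]])
W_q = torch.tensor([[1., 1.], [0., 1.]])
Q = A @ W_q                              # rows q^1 = (1, 2), q^2 = (1, 1)

# Assumed K and V (not given explicitly in the text, see above)
K = torch.tensor([[1., 0.], [0., 1.]])
V = torch.tensor([[1., 0.], [0., 1.]])

d = K.shape[-1]                          # d = 2
alpha = Q @ K.T / d ** 0.5               # raw scores alpha_{i,j}
alpha_hat = torch.softmax(alpha, dim=-1) # row-wise softmax
B = alpha_hat @ V                        # rows b_1, b_2

print(alpha)
print(B)
```

The first row of `alpha` is about (0.71, 1.41). The rows of `B` come out at about (0.33, 0.67) and (0.50, 0.50), matching the hand calculation.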
That completes Self-Attention. In summary, it comes down to this formula from the paper:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Among them, Q, K, and V are calculated from the input sequence. The softmax function is used to calculate the weight corresponding to each position. The final output is the weighted sum of V, and the weight is the output of the softmax function. This formula is a core component of the Transformer model and is widely used in natural language processing and other sequence data processing tasks.

import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, dim, dk, dv):
        super(Self_Attention, self).__init__()
        self.scale = dk ** -0.5  # the 1/sqrt(dk) factor from the formula
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)  # v's dimension need not equal that of q and k

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim = -1)

        x = attn @ v

        return x

att = Self_Attention(dim=2, dk=2, dv=3)
x = torch.rand((1, 4, 2))  # 1 is batch_size, 4 is the number of tokens, 2 is the length of each token
print(x)
output = att(x)
print(output.shape)  # torch.Size([1, 4, 3])

3.4 Multi-Head Attention


First, as in the Self-Attention module, $a_i$ is multiplied by $W^q$, $W^k$, $W^v$ to obtain $q^i$, $k^i$, $v^i$. Then, according to the number of heads $h$, $q^i$, $k^i$, $v^i$ are each split evenly into $h$ parts. For example, with $h=2$, $q^1$ is split into $q^{1,1}$ and $q^{1,2}$. $q^{1,1}$ belongs to head1 and $q^{1,2}$ belongs to head2.

Readers who have read the original paper may be puzzled here, because that is not what the paper writes. In the paper, each head's $Q_i$, $K_i$, $V_i$ are obtained by mapping through $W^Q_i$, $W^K_i$, $W^V_i$:

$$\mathrm{head}_i = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)$$

However, some source code on GitHub simply splits $Q$, $K$, $V$ evenly. This is consistent with the paper. If $W^Q_i$, $W^K_i$, $W^V_i$ are set to suitable selection matrices, multiplying by them performs exactly that even split. For example, in the figure below, multiplying $Q$ by $W^Q_1$ yields the first equal slice $Q_1$.
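This equivalence can be checked directly. The sketch assumes 3 tokens, $d_{model} = 4$ and $h = 2$. `W1` and `W2` are illustrative selection matrices built from identity columns. They stand in for $W^Q_1$ and $W^Q_2$ and are not learned weights.

```python
import torch

torch.manual_seed(0)
seq_len, d_model, h = 3, 4, 2
d_k = d_model // h
Q = torch.randn(seq_len, d_model)

# Even split, as done in many implementations
Q1_split, Q2_split = Q[:, :d_k], Q[:, d_k:]

# Paper-style projection with selection matrices of shape [d_model, d_k]
eye = torch.eye(d_model)
W1 = eye[:, :d_k]  # identity block on top, zeros below
W2 = eye[:, d_k:]  # zeros on top, identity block below

print(torch.allclose(Q @ W1, Q1_split), torch.allclose(Q @ W2, Q2_split))  # True True
```

In a trained model the $W^Q_i$ are learned rather than fixed selections. Splitting a learned $d_{model} \times d_{model}$ projection is equivalent to learning one $W^Q_i$ per head.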

This gives each $\mathrm{head}_i$ its own $Q_i$, $K_i$, $V_i$. Each head then applies the same procedure as in Self-Attention to obtain its result:

$$\mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i$$

Here $Q_i$, $K_i$, $V_i$ are computed from the input sequence, and the softmax function gives the weight for each position. The final output is the weighted sum of $V_i$ using those weights.

Next, the results of all heads are concatenated. In the figure below, $b_{1,1}$ (the $b_1$ obtained by $\mathrm{head}_1$) and $b_{1,2}$ (the $b_1$ obtained by $\mathrm{head}_2$) are concatenated. Likewise, $b_{2,1}$ (the $b_2$ obtained by $\mathrm{head}_1$) and $b_{2,2}$ (the $b_2$ obtained by $\mathrm{head}_2$) are concatenated.

The concatenated results are then fused by $W^O$ (a learnable parameter), as shown in the figure below. This gives the final outputs $b_1, b_2$.
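The full pipeline can be written as a short function: project, split evenly into heads, attend per head, concatenate, and fuse with $W^O$. `multi_head_attention` is a hypothetical helper for a single unbatched sequence, with biases ignored as in the text. The random matrices stand in for learned weights.

```python
import torch

def multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads):
    """Sketch: even head split, per-head attention, concat, then W^O (no bias)."""
    seq_len, d_model = x.shape
    d_k = d_model // n_heads
    Q, K, V = x @ W_q, x @ W_k, x @ W_v      # each [seq_len, d_model]
    head_outputs = []
    for i in range(n_heads):
        cols = slice(i * d_k, (i + 1) * d_k)  # columns owned by head i
        scores = Q[:, cols] @ K[:, cols].T / d_k ** 0.5
        head_outputs.append(torch.softmax(scores, dim=-1) @ V[:, cols])  # [seq_len, d_k]
    concat = torch.cat(head_outputs, dim=-1)  # splice b_{j,1}, ..., b_{j,h} for each token j
    return concat @ W_o                       # fuse the heads with W^O -> [seq_len, d_model]

torch.manual_seed(0)
d_model = 8
x = torch.randn(5, d_model)
W_q, W_k, W_v, W_o = (torch.randn(d_model, d_model) for _ in range(4))
out = multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads=2)
print(out.shape)  # torch.Size([5, 8])
```

As a sanity check, with one head and $W^O$ set to the identity, the function reduces to plain Self-Attention.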

Code:

import torch  # import PyTorch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()  # initialize the nn.Module base class
        self.n_heads = n_heads  # number of attention heads
        self.d_model = d_model  # dimension of the input vectors
        self.d_k = d_model // n_heads  # dimension of each head
        self.dropout = nn.Dropout(p=dropout)  # dropout applied to the attention weights

        # Weight matrices for Query, Key and Value
        self.W_q = nn.Linear(d_model, n_heads * self.d_k)  # Query weight matrix
        self.W_k = nn.Linear(d_model, n_heads * self.d_k)  # Key weight matrix
        self.W_v = nn.Linear(d_model, n_heads * self.d_k)  # Value weight matrix

        # Output weight matrix
        self.W_o = nn.Linear(n_heads * self.d_k, d_model)  # output projection (W^O)

    def forward(self, x, mask=None):
        # x has shape [batch_size, seq_len, d_model]
        batch_size, seq_len, d_model = x.size()

        # Compute Q, K, V with the weight matrices and split the last dimension into heads
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)

        # Move the head dimension forward and merge it into the batch for torch.bmm
        Q = Q.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
        K = K.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
        V = V.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)

        # Compute the attention weights
        scores = torch.bmm(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = nn.Softmax(dim=-1)(scores)
        attn_weights = self.dropout(attn_weights)

        # Compute the output vectors and concatenate the heads
        attn_output = torch.bmm(attn_weights, V)
        attn_output = attn_output.view(batch_size, self.n_heads, seq_len, self.d_k)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len,
                                                                        self.n_heads * self.d_k)
        output = self.W_o(attn_output)

        return output


# Define the input vectors
x = torch.randn(2, 10, 128)

# Define the attention module
attn = MultiHeadAttention(n_heads=8, d_model=128)

# Run the forward pass
output = attn(x)

# Print the shape of the output
print(output.shape)  # Output: torch.Size([2, 10, 128])

4. Vision Transformer

Vision Transformer (ViT) is a Transformer-based image classification model. The figure below is the ViT model framework given in the original paper.

Simply put, the ViT model consists of three modules:

  • Linear Projection of Flattened Patches (Embedding layer)
  • Transformer Encoder (a more detailed structure is given on the right side of the figure)
  • MLP Head (the layer structure ultimately used for classification)

Detailed explanation of Embedding layer structure

The standard Transformer module expects a sequence of tokens (vectors) as input, i.e. a two-dimensional matrix [num_token, token_dim]. As shown in the figure below, each token corresponds to a vector. For ViT-B/16, each token vector has length 768.

Image data, by contrast, has the format [H, W, C]. This is a three-dimensional matrix, which is clearly not what the Transformer expects, so the data must first be transformed by an Embedding layer. As shown in the figure below, the image is first divided into patches of a given size. For ViT-B/16, the 224x224 input image is divided into 16x16 patches, giving $(224/16)^2 = 196$ patches. Each patch is then mapped to a one-dimensional vector by a linear mapping. For ViT-B/16, each patch of shape [16, 16, 3] is mapped to a vector of length 768 (hereafter simply called a token): $[16, 16, 3] \rightarrow [768]$.

In code, this is implemented directly with a convolutional layer. For ViT-B/16, it is a convolution with kernel size 16x16, stride 16, and 768 kernels. The convolution maps $[224, 224, 3] \rightarrow [14, 14, 768]$. Flattening the H and W dimensions then gives $[14, 14, 768] \rightarrow [196, 768]$. The result is a two-dimensional matrix, which is exactly what the Transformer needs.
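The claim that a strided convolution equals "split into patches, flatten, apply a linear map" can be checked directly. To keep the check fast, the sketch shrinks the sizes to a 32x32 image, 8x8 patches and embedding dim 16; ViT-B/16 uses 224, 16 and 768. `F.unfold` cuts the image into flattened non-overlapping patches, and the conv weights, reshaped to a matrix, serve as the linear map.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
img, patch, in_c, embed_dim = 32, 8, 3, 16
x = torch.randn(1, in_c, img, img)                      # [B, C, H, W]

proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch, stride=patch)
tokens_conv = proj(x).flatten(2).transpose(1, 2)        # [B, num_patches, embed_dim]

# Same thing by hand: cut into non-overlapping patches, flatten each, apply a linear map
patches = F.unfold(x, kernel_size=patch, stride=patch)  # [B, C*patch*patch, num_patches]
weight = proj.weight.reshape(embed_dim, -1)             # [embed_dim, C*patch*patch]
tokens_linear = (weight @ patches).transpose(1, 2) + proj.bias

print(tokens_conv.shape)  # torch.Size([1, 16, 16])
```

With `img=224, patch=16, embed_dim=768`, the same code yields the [1, 196, 768] token matrix described above.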

Before the tokens enter the Transformer Encoder, a [class] token and a Position Embedding must be added. Following BERT, the authors insert a [class] token dedicated to classification into the tokens just obtained. This [class] token is a trainable parameter with the same format as the other tokens; for ViT-B/16 it is a vector of length 768. It is concatenated with the tokens generated from the image: $\mathrm{Cat}([1, 768], [196, 768]) \rightarrow [197, 768]$. The Position Embedding is the Positional Encoding discussed earlier for the Transformer. Here it is a trainable parameter (1D Pos. Emb.) that is added directly to the tokens, so its shape must match theirs. For ViT-B/16, the shape after concatenating the [class] token is $[197, 768]$, so the Position Embedding is also $[197, 768]$.
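The two additions amount to one concatenation and one element-wise sum. In this sketch, both parameters are zero-initialized for simplicity, which is an assumption; real implementations typically use a small random initialization. A batch of 2 random token matrices stands in for the Embedding layer's output.

```python
import torch
import torch.nn as nn

batch, num_patches, embed_dim = 2, 196, 768
patch_tokens = torch.randn(batch, num_patches, embed_dim)  # output of the Embedding layer

# Trainable [class] token and 1D position embedding (zero-initialized here for simplicity)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

# Cat([1, 768], [196, 768]) -> [197, 768], done for every image in the batch
x = torch.cat([cls_token.expand(batch, -1, -1), patch_tokens], dim=1)
# The position embedding is added element-wise, so its shape must match
x = x + pos_embed
print(x.shape)  # torch.Size([2, 197, 768])
```

The [class] token always sits at index 0. That is why the classifier later reads position 0 of the encoder output.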

The authors also ran a series of comparison experiments on the Position Embedding. The source code uses 1D Pos. Emb. by default. Compared with using no Position Embedding, it improves accuracy by about 3 points, and it differs little from 2D Pos. Emb.

Transformer Encoder detailed explanation

The Transformer Encoder simply stacks the Encoder Block L times. The figure below shows the Encoder Block as drawn by the author. It mainly consists of the following parts:

  • Layer Norm: this normalization method was mainly proposed for NLP. Here it normalizes each token separately. Layer Norm was covered in an earlier post.

  • Multi-Head Attention: this structure was discussed in detail in the Transformer section earlier, so it is not repeated here.

  • Dropout/DropPath: the code of the original paper uses plain Dropout layers, while rwightman's implementation uses DropPath (stochastic depth), which may work better.

  • MLP Block: as shown on the right side of the figure, it is a very simple stack of fully connected + GELU activation + Dropout. Note that the first fully connected layer expands the number of nodes fourfold, $[197, 768] \rightarrow [197, 3072]$. The second fully connected layer restores the original number, $[197, 3072] \rightarrow [197, 768]$.
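The parts listed above fit together as a compact Encoder Block sketch, using ViT-B/16 sizes (768-dim tokens, 12 heads). This sketch makes three assumptions. PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`, available in recent PyTorch versions) stands in for the Multi-Head Attention module. Dropout/DropPath are omitted. `EncoderBlockSketch` is a hypothetical name, not the `Block` class in the code below.

```python
import torch
import torch.nn as nn

class EncoderBlockSketch(nn.Module):
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # MLP Block: dim -> 4*dim -> dim with GELU in between
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # Layer Norm -> attention -> residual
        x = x + self.mlp(self.norm2(x))                    # Layer Norm -> MLP -> residual
        return x

block = EncoderBlockSketch()
tokens = torch.randn(1, 197, 768)
print(block(tokens).shape)  # torch.Size([1, 197, 768])
```

Because every sub-layer is wrapped in a residual connection, the block preserves the [197, 768] shape, which is what allows it to be stacked L times.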

Detailed explanation of MLP Head

After the Transformer Encoder, the output shape equals the input shape. For ViT-B/16, the input is [197, 768] and the output is still [197, 768]. Note that there is also a Layer Norm layer after the Transformer Encoder that is not drawn in the figure. The detailed ViT diagram drawn by the author, later in this post, shows it.

In the classification task, we only need to extract the results corresponding to [class]token, that is, [1, 768] corresponding to [class]token is extracted from [197, 768]. Then the final classification result is obtained through MLP Head. In the original paper, when training ImageNet21K, the MLP Head consists of a Linear layer, a tanh activation function, and a Linear layer. But when migrating to ImageNet1K or your own data set, you only need to use a Linear layer.
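Extracting the [class] token and applying a single Linear head takes two lines. The sketch assumes 1000 output classes (ImageNet-1K) and uses random numbers in place of a real encoder output.

```python
import torch
import torch.nn as nn

batch, num_tokens, embed_dim, num_classes = 2, 197, 768, 1000
encoder_out = torch.randn(batch, num_tokens, embed_dim)  # [B, 197, 768] after the final Layer Norm

cls_out = encoder_out[:, 0]               # output at the [class] token position, [B, 768]
head = nn.Linear(embed_dim, num_classes)  # single Linear head, as used when transferring
logits = head(cls_out)
print(cls_out.shape, logits.shape)  # torch.Size([2, 768]) torch.Size([2, 1000])
```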


Vision Transformer network structure drawn by myself


"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,   # dimension of the input tokens
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): activation layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both predictions; at inference (else branch), return their average
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  Password: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  Password: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    Link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  Password: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
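The factory functions above differ only in patch size, width, depth and head count. As a quick sanity check of the shapes noted in the comments (for example `[B, 197, 768]` in `forward_features`), the arithmetic that `PatchEmbed` and `Attention` perform can be reproduced in plain Python. This is a minimal sketch: `vit_token_stats` is a hypothetical helper, not part of the original code, and it assumes square inputs whose side is divisible by the patch size.

```python
def vit_token_stats(img_size=224, patch_size=16, embed_dim=768,
                    num_heads=12, distilled=False):
    """Hypothetical helper mirroring the shape arithmetic of PatchEmbed and Attention."""
    grid = img_size // patch_size          # patches per side (PatchEmbed.grid_size)
    num_patches = grid * grid              # tokens produced by PatchEmbed
    num_tokens = 2 if distilled else 1     # class token (+ distillation token)
    seq_len = num_patches + num_tokens     # sequence length seen by the Blocks
    head_dim = embed_dim // num_heads      # per-head width inside Attention
    scale = head_dim ** -0.5               # default qk scale
    return {"grid": grid, "num_patches": num_patches, "seq_len": seq_len,
            "head_dim": head_dim, "scale": scale}


# (patch_size, embed_dim, num_heads) as passed by the factory functions above
configs = {
    "ViT-B/16": (16, 768, 12),
    "ViT-B/32": (32, 768, 12),
    "ViT-L/16": (16, 1024, 16),
    "ViT-L/32": (32, 1024, 16),
    "ViT-H/14": (14, 1280, 16),
}

for name, (p, d, h) in configs.items():
    s = vit_token_stats(patch_size=p, embed_dim=d, num_heads=h)
    print(f"{name}: {s['grid']}x{s['grid']} patches, seq_len={s['seq_len']}, "
          f"head_dim={s['head_dim']}, scale={s['scale']:.4f}")
```

Smaller patches mean more tokens: ViT-H/14 works on 257 tokens versus 50 for the /32 models, and since the attention matrix is `seq_len x seq_len`, its cost grows quadratically with that number.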

The figure summarizes the whole process and makes it easy to follow.
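One step of that process that is easy to get wrong is the single `reshape`/`permute` in `Attention.forward`, which splits the fused qkv projection into heads. The sketch below, assuming PyTorch is installed, compares that batched computation with a naive loop that attends one head at a time. Both helper functions are hypothetical; dropout, the qkv bias and the output projection `self.proj` are left out, and `w_qkv` is a random stand-in for the transposed weight of the `qkv` linear layer.

```python
import torch


def batched_mhsa(x, w_qkv, num_heads):
    """Multi-head self-attention with one reshape/permute, as in Attention.forward
    (hypothetical helper: no dropout, no bias, no output projection)."""
    B, N, C = x.shape
    head_dim = C // num_heads
    scale = head_dim ** -0.5
    # [B, N, 3C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)  # [B, heads, N, N]
    return (attn @ v).transpose(1, 2).reshape(B, N, C)          # [B, N, C]


def looped_mhsa(x, w_qkv, num_heads):
    """Reference (hypothetical helper): slice each head's columns, attend head by head."""
    B, N, C = x.shape
    head_dim = C // num_heads
    q_all, k_all, v_all = (x @ w_qkv).split(C, dim=-1)  # each [B, N, C]
    outs = []
    for h in range(num_heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)
        q, k, v = q_all[..., cols], k_all[..., cols], v_all[..., cols]
        attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
        outs.append(attn @ v)                           # [B, N, head_dim]
    return torch.cat(outs, dim=-1)                      # heads concatenated -> [B, N, C]


torch.manual_seed(0)
x = torch.randn(2, 197, 768)               # [batch, 1 class token + 196 patches, dim]
w_qkv = torch.randn(768, 3 * 768) * 0.02   # hypothetical stand-in for qkv.weight.T
a = batched_mhsa(x, w_qkv, num_heads=12)
b = looped_mhsa(x, w_qkv, num_heads=12)
print(a.shape, torch.allclose(a, b, atol=1e-5))
```

The two results should agree because the permute order `(2, 0, 3, 1, 4)` assigns columns `h * head_dim` to `(h + 1) * head_dim` of q, k and v to head `h`, and the final `transpose(1, 2).reshape(B, N, C)` concatenates the heads back in that same order. The batched form simply replaces the Python loop with one large matrix multiplication.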


Reference: Sunflower’s small mung beans


Origin: blog.csdn.net/m0_47005029/article/details/131610694