5. Diving into ViT: Swin Transformer (Part 1)

Swin Transformer

To address the problem that ViT is not well suited to downstream tasks (such as detection and segmentation), the shifted window is proposed

Key features of Swin:

  1. Start from small patches and merge adjacent patches layer by layer to build a hierarchy

  2. Compute attention within local windows (Window Attention)

  3. A Shifted Window operation is proposed so that attention can be computed efficiently across windows

1. Paper reading notes

Swin Transformer uses shifted windows to build a hierarchical Vision Transformer: like a CNN it is divided into stages and extracts features hierarchically, and its computational complexity is linear in the image size.

1.1 Summary:

The paper points out the problems of moving the Transformer from NLP to vision:

  1. Visual entities vary hugely in scale (pedestrians and cars in a street scene come in all sizes), which has no counterpart in NLP

  2. At high resolution the token sequence becomes very long and the computation heavy

Previous solutions:

  1. Use later, downsampled feature maps as the Transformer input

  2. Split the image into patches to reduce the effective resolution

  3. Divide the image into small windows and perform self-attention within each window

This paper proposes the shifted window. Shifting not only keeps the computation low, but also lets adjacent windows interact, creating cross-window connections between consecutive layers. The resulting hierarchical structure is flexible and provides features at various scales, and because self-attention is computed inside small windows, the computational complexity grows only linearly with image size.

1.2 Conclusion:

insert image description here

Swin computes self-attention inside small windows, whereas ViT computes it over the whole image. As long as the window size is fixed, the cost of self-attention within one window is fixed, so the total cost grows linearly with image size: if the image area grows by a factor of x, the number of windows grows by x and so does the total cost, rather than by x².

This relies on the locality inductive bias familiar from CNNs: different parts of the same object (or different objects with similar semantics) are very likely to appear close together, so attention computed inside a small window is usually enough, and attending over the entire image would largely be a waste of computation.

CNNs obtain multi-scale features through pooling, which enlarges the receptive field seen by subsequent convolution kernels, so the features after each pooling step can capture objects of different sizes. This paper proposes patch merging, which combines adjacent patches into a larger patch, enlarging the receptive field and capturing multi-scale features. With multi-scale features at 4×, 8×, and 16× downsampling, detection can be handled by feeding them to an FPN and segmentation by feeding them to a U-Net, so Swin Transformer can serve as a general backbone network.
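For example, for the Swin-T configuration in the paper (C = 96, 224×224 input), the four stages produce feature maps of:

  Stage 1: 56×56×96 (4× downsampled)
  Stage 2: 28×28×192 (8×)
  Stage 3: 14×14×384 (16×)
  Stage 4: 7×7×768 (32×)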

insert image description here

After the shifted partition, adjacent windows can interact with each other

insert image description here

Take a 224×224×3 image. It is first partitioned into 4×4 patches, so each spatial side becomes 224/4 = 56 and each patch is a 4×4×3 = 48-dimensional vector. Linear Embedding then projects each patch to a preset dimension that the Transformer can accept, the hyperparameter C = 96, giving 56×56×96; flattened, that is a sequence of 3136 tokens of dimension 96. In ViT, with 16×16 patches, the sequence length is only 196, so 3136 is far too long for global self-attention. Swin therefore works window by window: each window contains only 7×7 = 49 patches, and self-attention is computed inside it (treat the block as a black box for now). Since attention does not change the tensor size, the output is still 56×56×96.
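A minimal shape check of these numbers (a sketch that uses a strided convolution as the patch projection, the same trick as the code in Section 3):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)                          # one 224x224 RGB image
patches = nn.Conv2d(3, 96, kernel_size=4, stride=4)(x)   # [1, 96, 56, 56]: 4x4 patches projected to C = 96
tokens = patches.flatten(2).permute(0, 2, 1)             # [1, 3136, 96]: 56*56 = 3136 tokens
print(patches.shape, tokens.shape)
print(56 // 7, (56 // 7) ** 2, 7 * 7)                    # 8 windows per side, 64 windows, 49 tokens per window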

In Patch Merging, a linear layer (equivalent to a 1×1 convolution) reduces the channel count from 4C to 2C. The overall effect is to halve the spatial size and double the number of channels, playing the same role as a pooling layer in a CNN.

insert image description here

insert image description here

insert image description here

2. Swin Transformer architecture analysis

insert image description here

Similar to ViT, the input image first goes through Patch Partition (splitting the image) and Patch Embedding, and then through four stages, analogous to the stages in ResNet. Each stage consists mainly of Swin Transformer Blocks, followed by a Patch Merging step for fusion.

The most critical operations are Swin Transformer Block and Patch Merging

The figure above shows the overall structure of the model; it also helps to follow how the data flows through it.

insert image description here

Let's break the pipeline apart step by step:

1. The image is fed into the network; for a color image the channel count is 3. After Patch Embedding, the channel count becomes embed_dim.

insert image description here

2. After Patch Embedding, the patches are cut again, this time with windows. The input at this point is already a feature-level tensor; a Window Partition splits it into non-overlapping windows.

insert image description here

3. Without the window partition, every patch would attend to every other patch. After partitioning, attention is computed separately inside each window, which cuts the computation: a patch no longer attends to patches in other windows. Attention keeps the output dimension equal to the input, so after processing every window the result is a tensor of the same size.

insert image description here

4. Patch Merging

insert image description here

In the Swin Transformer, every 2×2 group of adjacent image tokens is fused into one, so the spatial size shrinks while embed_dim doubles.

5. Next Stage

insert image description here

After one stage finishes, the next stage begins. Its input is the smaller, merged tensor, and the same steps repeat: partition into windows, shrink the spatial size, and increase the channel dimension.

insert image description here

Within a stage, the block can be repeated several times; repeating it changes neither the spatial size nor the input/output dimensions.

2.1 Swin Transformer Block

This section describes how Block is constructed

insert image description here

insert image description here

A block is composed of W-MSA (Window Multi-head Self-Attention) and SW-MSA (Shifted Window Multi-head Self-Attention). The shifted-window part is not covered here yet; we only look at the left branch: the input goes through a LayerNorm, then W-MSA with a residual connection, then another LayerNorm, an MLP, and another residual connection. This is the same as a standard Transformer block except that the attention is replaced by W-MSA.
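In compact form, the left (W-MSA) branch of each block pair computes the following, and the second block in the pair is identical except that W-MSA is replaced by SW-MSA (this matches the block equations in the paper):

  x = x + W-MSA(LN(x))
  x = x + MLP(LN(x))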

W-MSA

insert image description here

After the window partition, each window only attends to itself: the tokens of window 1 (16 of them in this example) attend to one another, then the 16 tokens of window 2 attend to one another, and so on, with every window handled independently.

insert image description here

The paper argues that W-MSA is much cheaper than global MSA; working out the complexity formulas makes this concrete.

insert image description here
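For reference, the two complexity formulas from the paper are (C is the channel dimension, h×w the number of patches, M the window size):

  Ω(MSA)   = 4hwC² + 2(hw)²C
  Ω(W-MSA) = 4hwC² + 2M²hwC

With h = w = 56 and M = 7, the second term is (hw)² = 3136² for MSA but only M²·hw = 49·3136 for W-MSA, i.e. a factor of hw/M² = 64 smaller.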

The two formulas differ only in their second term: for MSA it grows quadratically with the number of patches h·w, whereas for W-MSA it grows only linearly, because the quadratic factor is the fixed window size M². The larger the image, the more efficient W-MSA is compared with plain MSA.

2.2 Patch Merging

insert image description here

Within each window, the tokens are grouped by their position in every 2×2 neighbourhood (the four colours in the figure) and the four groups are concatenated along the channel dimension. After merging, the channel dimension is 4× the original and the number of tokens is 1/4 of the original; a linear layer then maps the 4× channels down to 2×.

Finally the result is reshaped back, so the height and width become half of the original and the channel dimension doubles.
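A minimal sketch of this shape bookkeeping (assuming C = 96 and a 56×56 token grid, matching the running example):

import torch
import torch.nn as nn

x = torch.randn(1, 56, 56, 96)
# gather the four tokens of every 2x2 neighbourhood and concatenate along channels
x = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2],
               x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1)   # [1, 28, 28, 384]
x = nn.Linear(384, 192)(x)                                    # [1, 28, 28, 192]
print(x.shape)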

3. Code implementation

The implementation involves W-MSA, Patch Merging, and Window Partition.

insert image description here

Window Partition cuts the tensor into windows and sends them to attention, which needs the three tensors Q, K, and V. Suppose a batch holds 3 samples of the same size; the windows in the red boxes are cut out, and attention is computed within each window independently.

All the small windows of a batch can be stacked together: windows never interact with one another, each window only attends to itself, so it does not matter how they are arranged in the batch.

insert image description here

Each small grid cell must attend to all the other cells in its window; this is window attention. The 16 tokens of each small window (in this example) are then pulled out and flattened, that is:

insert image description here

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=4, embed_dim=96):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)      #[n, embed_dim, h', w']
        x = x.flatten(2)       #[n, embed_dim, h'w']
        x = x.permute(0, 2, 1)  # [n,  h'*w', embed_dim]
        x = self.norm(x)
        return x

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        h, w = self.resolution
        b, _, c = x.shape   # _ is num_patches, i.e. h*w (unused here)
        x = x.reshape([b, h, w, c])

        x0 = x[:, 0::2, 0::2, :]   # top-left token of every 2x2 neighbourhood
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # [B, h/2, w/2, 4c]
        x = x.reshape([b, -1, 4*c])
        x = self.norm(x)
        x = self.reduction(x)

        return x

class Mlp(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio),dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

def windows_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.reshape([B, H//window_size,window_size, W//window_size, window_size, C])
    x = x.permute([0,1, 3, 2, 4, 5])
    # [B, h//ws, w//ws, ws, ws, c]
    x = x.reshape([-1, window_size, window_size, C])
    # [B * num_windows, ws, ws, c]
    return x


def windows_reverse(windows, window_size, H, W):
    B = int(windows.shape[0]// (H/window_size * W/window_size))
    x = windows.reshape([B, H//window_size, W//window_size, window_size, window_size, -1])
    x = x.permute([0, 1, 3, 2, 4, 5])
    x = x.reshape([B, H, W, -1])
    return x
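As a quick sanity check (not part of the original code, and assuming H and W are divisible by the window size), windows_partition and windows_reverse should be exact inverses of each other:

x = torch.randn(2, 56, 56, 96)
w = windows_partition(x, 7)          # [2 * 64, 7, 7, 96]
y = windows_reverse(w, 7, 56, 56)    # [2, 56, 56, 96]
assert torch.equal(x, y)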

Next, WindowAttention is defined; inside the block it is combined with the window partition and reverse functions above.

insert image description here

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.dim_head = dim// num_heads
        self.num_heads = num_heads
        self.scale = self.dim_head ** -0.5
        self.softmax = nn.Softmax(-1)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def transpose_multi_head(self, x):
        new_shape = x.shape[:-1] + (self.num_heads, self.dim_head)
        x = x.reshape(new_shape)
        x = x.permute(0, 2, 1, 3)  #[B, num_heads, num_patches, dim_head]
        return x

    def forward(self,x):
        # x: [B, num_patches, embed_dim]
        B, N, C = x.shape
        qkv = self.qkv(x).chunk(3, -1)
        q, k, v = map(self.transpose_multi_head, qkv)

        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-1,-2))
        attn = self.softmax(attn)

        out = torch.matmul(attn, v)   # [B, num_heads, num_patches, dim_head]
        out = out.permute([0, 2, 1, 3])
        #  # [B, num_patches, num_heads, dim_head]    num_heads * dim_head= embed_dim
        out = out.reshape([B, N, C])
        out = self.proj(out)
        return out


class SwinBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size):
        super().__init__()
        self.dim = dim
        self.resolution = input_resolution
        self.window_size = window_size

        self.attn_norm = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)

        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = Mlp(dim)

    def forward(self, x):
        H, W = self.resolution
        B, N, C = x.shape

        h = x
        x = self.attn_norm(x)
        # partition into windows
        x = x.reshape([B, H, W, C])
        x_windows = windows_partition(x, self.window_size)
        # [B * num_windows, ws, ws, c]
        x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C])

        attn_windows = self.attn(x_windows)

        # after attention, reverse the window partition
        attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
        x = windows_reverse(attn_windows, self.window_size, H, W)
        # [B, H, W, C]
        # the MLP expects the flattened token sequence, so reshape back
        x = x.reshape([B, H * W, C])

        x = h + x

        h = x

        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
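Before wiring everything together, a quick shape check for WindowAttention on its own (a sketch, not part of the original code; note that this simplified version omits the relative position bias used in the paper):

attn = WindowAttention(dim=96, window_size=7, num_heads=4)
windows = torch.randn(64, 49, 96)   # 64 windows of 7*7 = 49 tokens each
print(attn(windows).shape)          # torch.Size([64, 49, 96])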

Finally, a main function ties everything together:


def main():
    t = torch.randn([4, 3, 224, 224])
    patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
    swinBlock = SwinBlock(dim=96, input_resolution=[56, 56], num_heads=4, window_size=7)
    patch_merging = PatchMerging(input_resolution=[56,56], dim=96)

    out = patch_embedding(t)  # [4, 3136, 96]
    print('patch_embedding out shape =', out.shape)
    out = swinBlock(out)
    print('swinBlock out shape= ',out.shape)
    out = patch_merging(out)
    print('patch_merging out shape= ',out.shape)



if __name__ == '__main__':
    main()

insert image description here

First we feed in a batch of data of shape [4, 3, 224, 224], i.e. a batch size of 4. patch_embedding takes 4×4 patches (patch_size = 4), so the tensor becomes [4, 3136, 96]: the 56×56 grid of patches is flattened to 3136 tokens, ready for the attention step that follows.

swinBlock applies windows_partition and WindowAttention; the tensor shape is left unchanged.

Finally, patch_merging behaves like pooling: every 4 adjacent tokens are merged, the channel dimension doubles from 96 to 192, and the number of tokens drops to 784 = 28×28, i.e. the 56×56 grid is halved along each side.

insert image description here

So most of the window machinery is just reshaping the tensor into windows, running attention, and reshaping it back.


Source: blog.csdn.net/qq_45807235/article/details/129178939