Neural Network Study Notes 5 - Swin-Transformer Network

Series Article Directory

Neural Network Study Notes 1 - ResNet Residual Network, Batch Normalization Understanding and Code
Neural Network Study Notes 2 - VGGNet Neural Network Structure and Receptive Field Understanding and Code
Reference Blog 1
Reference Blog 2



Foreword

What is Swin Transformer?

  1. Swin Transformer is a hierarchical Transformer designed specifically for vision. Its two key features are shifted windows and hierarchical representation. The shifted-window scheme computes self-attention inside local, non-overlapping windows while still allowing connections across windows. The hierarchical structure lets the model handle images of different scales, and the computational complexity grows linearly with image size, which is why it is sometimes described as a CNN wearing a Transformer's skin.
  2. Swin Transformer borrows the hierarchical structure of CNNs. It can not only do classification but, like a CNN, can also be extended to downstream tasks. It is a general-purpose backbone for computer vision and can be used for downstream vision tasks such as image classification, image segmentation, and object detection.
  3. It takes ViT as its starting point, and its design absorbs ideas from ResNet: going from local to global, the Transformer is built to gradually enlarge the receptive field. Its success is not accidental but the result of long accumulation.

What problem did it solve?

  1. Unlike in NLP, objects of the same class in vision can vary greatly in scale across images and even within a single image. In one image containing many pedestrians, the people may be large or small, near or far, so targets with the same semantics can differ greatly in scale;
  2. Compared with text, images are much larger, so the computational complexity of global self-attention is high;
  3. Compared with the earlier ViT, two improvements are made: the hierarchical construction commonly used in CNNs is introduced to build a hierarchical Transformer, and the idea of locality is introduced to compute self-attention within non-overlapping window regions.

What are the advantages?

  1. A hierarchical network structure is proposed to address the multi-scale nature of images and provide feature maps at each scale;
  2. Shifted windows are proposed: the shift lets adjacent windows interact, while restricting attention to windows greatly reduces the Transformer's computational cost;
  3. Computational complexity grows linearly with image size rather than quadratically, so the model can be applied widely across computer vision;

What is the conclusion?

  1. Transformer can fully replace CNN across vision tasks; Swin is hailed as a new direction, even a new era, for the CV field.

How is the effect?

  1. It is not SOTA on ImageNet classification, achieving roughly the same performance as EfficientNet.
  2. Swin Transformer's advantage is not classification, where the improvement is modest, but downstream tasks such as detection and segmentation, where the improvement is large.
  3. The paper was released in March 2021 and immediately topped the leaderboards of multiple vision tasks.

1. Patch Merging operation

[Figure: ViT's single-scale 16× patches vs. Swin's hierarchical feature maps]
ViT uses a patch size of 16×16, i.e. a 16× downsampling rate. From bottom to top the token size never changes, and global modeling is achieved with global self-attention. However, this single-scale processing handles objects of different sizes poorly, and for large images the sequence length is still too long, with the computational complexity growing quadratically.

In dense prediction tasks such as detection and segmentation, and in images used in real-world projects, multi-scale handling is a very important issue, and mature models have dedicated multi-scale feature processing.

Swin Transformer performs self-attention inside small windows (the window concept is covered in Section 2). These windows, composed of patches, differ from ViT's patches in that they are relatively independent. For example, at 4× downsampling the feature map is divided into disjoint window regions, and multi-head self-attention is computed only among the patches within each window.
To let disjoint windows exchange information and to learn multi-scale features, the paper proposes Patch Merging: neighboring small patches are merged into a larger patch to enlarge the receptive field, and features are gathered by index selection and stacked along the depth dimension, simulating a pooling-like operation.
In detail, downsampling is done by a Patch Merging layer, as shown in the figure below. To downsample by 2×, four neighboring small patches are first grouped into one large patch; the small patches are labeled 1, 2, 3, 4, and elements are picked every other position, i.e. all patches with the same label are gathered together. This turns the original tensor into four tensors, which are concatenated along the depth dimension, changing the shape from h × w × c to h/2 × w/2 × 4c, followed by a LayerNorm layer. To stay analogous to CNNs, where each pooling step only doubles the channel count, the depth should grow by 2× rather than 4×, so a 1×1 convolution (or fully connected layer) is immediately applied along the channel dimension to reduce the channels to 2c, giving an output of h/2 × w/2 × 2c. In short, after a Patch Merging layer the height and width of the feature map are halved and the depth is doubled.
[Figure: Patch Merging operation]
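The following is a minimal PyTorch sketch of the Patch Merging step described above, assuming the input is laid out as [B, H, W, C] with even H and W (an illustrative module, not the paper's exact implementation):

```python
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Minimal sketch: halves H and W, doubles the channel depth."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 4c -> 2c (the "1x1 conv")

    def forward(self, x):                     # x: [B, H, W, C]
        x0 = x[:, 0::2, 0::2, :]              # every other patch, offset (0, 0)
        x1 = x[:, 1::2, 0::2, :]              # offset (1, 0)
        x2 = x[:, 0::2, 1::2, :]              # offset (0, 1)
        x3 = x[:, 1::2, 1::2, :]              # offset (1, 1)
        x = torch.cat([x0, x1, x2, x3], -1)   # [B, H/2, W/2, 4C]: concat along depth
        return self.reduction(self.norm(x))   # [B, H/2, W/2, 2C]
```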


2. W-MSA, SW-MSA and cyclic shift window design

1. Window self-attention W-MSA

Computing global self-attention as in ViT leads to quadratic complexity. For downstream vision tasks, especially dense prediction, or for very large images, the computational cost of global self-attention is far higher than that of convolution.

The paper proposes computing self-attention within windows, i.e. Windows Multi-head Self-Attention (W-MSA). The purpose of the W-MSA module is to reduce the amount of computation.

As shown on the left of the figure below, with the ordinary Multi-head Self-Attention (MSA) module every pixel of the feature map (also called a token or patch), together with the class token, must compute attention with all other pixels globally.

On the right of the figure, the feature map is split into non-overlapping windows. With the W-MSA module, the feature map is first divided into windows of size M×M (M=2 in the example), and self-attention is then computed only within each window.

Suppose a 224×224×3 image is fed into Swin Transformer and the patch size is 4×4; that gives 56×56 patches. With every 7×7 patches forming a window, the image yields 8×8 = 64 windows.
[Figure: global MSA (left) vs. window-based W-MSA (right)]
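As a sketch (not the paper's exact code), the window split for the 224×224 example above could look like this, assuming the tokens are stored as a [B, H, W, C] tensor with C = 96:

```python
import torch

def window_partition(x, window_size):
    """Split a [B, H, W, C] token map into non-overlapping window_size x window_size windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size, window_size, C)   # [num_windows * B, M, M, C]

x = torch.randn(1, 56, 56, 96)                       # 224x224 image, 4x4 patches -> 56x56 tokens
print(window_partition(x, 7).shape)                  # torch.Size([64, 7, 7, 96]) -> 8x8 = 64 windows
```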

The following two formulas are given in the original paper, where the computational complexity of Softmax is ignored:
Ω(MSA) = 4hwC^2 + 2(hw)^2 C        (1)
Ω(W-MSA) = 4hwC^2 + 2M^2 hwC       (2)

  1. h represents the height of the feature map
  2. w represents the width of the feature map
  3. C represents the depth of the feature map
  4. M represents the size of each window, in units of patches

Comparing formula (1) and formula (2): the first terms are identical, and only the second term changes from (hw)^2 to M^2 * h * w. That looks like a small difference, but plugging in the numbers shows the gap is huge: with hw = 56×56 = 3136 and M^2 = 49, the second term alone shrinks by a factor of hw / M^2 = 64, so the overall difference is dozens of times.
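Plugging the example numbers into the two formulas gives a rough sense of the saving (the Softmax cost is still ignored):

```python
# h = w = 56 tokens, C = 96 channels, window size M = 7
h, w, C, M = 56, 56, 96, 7

msa   = 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C    # formula (1): global MSA
w_msa = 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C  # formula (2): window MSA

print(f"MSA:   {msa / 1e9:.2f} GFLOPs")              # ~2.00
print(f"W-MSA: {w_msa / 1e9:.2f} GFLOPs")            # ~0.15
print(f"attention-term ratio: {(h * w) / M ** 2:.0f}x")  # 64x
```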

2. Shifted window self-attention SW-MSA

[Figure: the regular window partition in layer l (left) and the shifted partition in layer l+1 (right)]
The original purpose of the Transformer is to model context, i.e. to pass and exchange information. With the W-MSA module, self-attention is computed only inside each window, so no information flows between windows. To solve this, the authors introduce the Shifted Windows Multi-Head Self-Attention (SW-MSA) module, i.e. a shifted version of W-MSA. Comparing the left and right images, the windows have been shifted (you can think of the window grid as moving right and down by M/2 patches from the top-left corner).

Layer l uses W-MSA and layer l+1 uses SW-MSA. In layer l, each patch can only attend to patches in the same window. In layer l+1, because the windows have shifted, some patches fall into new windows; patches carrying information from different previous windows can now attend to each other. This is the cross-window connection, which enables interaction between windows. Combined with Patch Merging, by the last few layers each patch has exchanged information with most of the feature map: the receptive field is already large and covers most of the image. The local attention gradually spreads globally, achieving the effect of global attention in a roundabout way. Put simply, the 4×4 window at the center of layer l+1 learns and fuses information from the four windows of layer l, because its patches come from those four windows; after W-MSA in layer l, each of those patches already carries the information of its own window, so the central window of layer l+1 fuses information from all four neighboring windows of layer l.

3. Optimized cyclic shift for window movement

[Figure: after the shift, the 4 windows become 9 windows of different sizes]

However, the window shift in SW-MSA introduces a problem. Although patches can now exchange information across windows, the partition changes: before the shift, layer l has four windows with 16 patches each; after the shift, layer l+1 has nine windows of different sizes, containing 4, 8, or 16 patches.

[Figure: zero-padding the incomplete windows to full size]
A simple fix is zero padding: for example, pad a 4-patch window with 12 zeros to reach the 16-patch window format; windows of 4 get 12 zeros and windows of 8 get 8 zeros. After padding there are 9 windows, which are then batched together for computation. The method is straightforward, but the number of windows per batch grows from 4 to 9, so the amount of computation and the complexity actually increase.

[Figure: cyclic shift of regions A, B, and C to the bottom right]
Swin Transformer instead proposes a cyclic shift combined with a mask. The specific method is:

  1. Temporarily label the incomplete windows at the top and left of layer l+1 as A, B, and C.
  2. Move A, B, and C to the bottom and right of layer l+1: A goes to the opposite corner, while B and C move to the opposite side.
  3. Partition the resulting feature map into 4 windows again; the original central 16-patch window is unchanged, and the other incomplete pieces are stitched into new 16-patch windows.

This operation lets patches from different windows be mixed without increasing the number of windows or the computational cost the way zero padding does. But it creates a new problem: the original central 16-patch window is unchanged and its patches are true pixel neighbors, so attending to each other makes sense; the other three stitched 16-patch windows, however, are assembled from different regions of the feature map. If self-attention is computed across them directly, the learned features may be garbled; in other words, a stitched window cannot be treated as a genuine window for self-attention.
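The cyclic shift itself can be sketched with torch.roll, assuming the feature map is stored as [B, H, W, C] and the shift is M // 2 (a sketch of the idea only, not the full implementation):

```python
import torch

M = 7
shift = M // 2                                   # shift by half a window
x = torch.randn(1, 14, 14, 96)                   # [B, H, W, C] token map

shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))       # A/B/C move to bottom/right
# ... masked W-MSA is computed on `shifted` (see the mask discussion below) ...
restored = torch.roll(shifted, shifts=(shift, shift), dims=(1, 2))  # reverse shift afterwards
assert torch.equal(restored, x)
```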

To handle the stitched windows, Swin Transformer uses a masking operation.
[Figure: the 14×14 feature map after the cyclic shift, with regions numbered 0-8]
For example, take a 14×14 feature map that has been shifted and stitched. Region 0 occupies 7×7 patches, regions 1 and 3 are 4×7, regions 2 and 6 are 3×7, region 4 is 4×4, regions 5 and 7 are 3×4, and region 8 is 3×3, for a total of 14×14 patches (the window grid is offset right and down by M/2 patches from the top-left corner).

Region 0 is a complete window and can use self-attention directly. Regions 3 and 6 together form a stitched window, which cannot use self-attention directly on its own.

[Figures: flattening regions 3 and 6 into vector A, the transpose B, and the attention matrix C]

So, as before, take the patches of regions 3 and 6 and flatten them into a vector A: A contains 4×7 = 28 values from region 3 and 3×7 = 21 values from region 6. Transposing A gives vector B. Multiplying A and B (as in the self-attention computation) gives the attention matrix C, which can be divided into four kinds of blocks:

  1. Region-3 values of A multiplied with region-3 values of B: the 3×3 block
  2. Region-3 values of A multiplied with region-6 values of B: the 3×6 block
  3. Region-6 values of A multiplied with region-3 values of B: the 6×3 block
  4. Region-6 values of A multiplied with region-6 values of B: the 6×6 block

Of these, the 3×3 and 6×6 blocks correspond to genuine self-attention within a region, while the 3×6 and 6×3 blocks are meaningless cross-region products of the stitching. We therefore only need the 3×3 and 6×6 data, and the 3×6 and 6×3 blocks must be masked.

So how is the region 3 + region 6 window handled? Swin Transformer uses a clever trick: a mask template matrix D is added to matrix C. The values in C are small (roughly on the order of zero), so adding 0 to the 3×3 and 6×6 blocks leaves them unchanged, while adding -100 to the 3×6 and 6×3 blocks turns them into large negative numbers. After the softmax, those entries become essentially 0, and what remains is exactly the 3×3 and 6×6 data we need.

[Figure: mask template for the window formed by regions 3 and 6]

Having covered the 3+6 window, look at the 1+2 window. Following the same reasoning, the 1+2 window turns out to be quite different from the 3+6 window, and the difference comes from how vector A is flattened.
[Figures: vector A, its transpose B, and matrix C for the window formed by regions 1 and 2]

As can be seen, the values of regions 1 and 2 are interleaved in vector A, which in turn changes the transposed vector B and the self-attention matrix C.

The main change is in matrix C: it still contains 4 kinds of blocks (1×1, 1×2, 2×1, 2×2), but they are no longer grouped into contiguous regions; instead they form alternating horizontal and vertical stripes, which also changes the design of the mask template D. Since this pattern is tedious to draw cell by cell, it is not drawn here; refer to the mask templates provided by Swin Transformer.

[Figure: the mask templates of Swin Transformer]

As for the 4+5+7+8 window, it is essentially a combination of the 3+6 and 1+2 cases. A figure of the flattened vector A is given below; you can work it out yourself by combining the patterns of the 3+6 and 1+2 windows. The corresponding mask template is shown in the figure above.
[Figure: flattened vector A for the window formed by regions 4, 5, 7, and 8]
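A sketch of how such an attention mask could be generated for all shifted windows at once, assuming a 14×14 token map with M = 7 and shift M // 2 = 3 (the region numbering differs from the figures, but the idea, including the -100 fill, is the same):

```python
import torch

H, W, M = 14, 14, 7
shift = M // 2

img_mask = torch.zeros(1, H, W, 1)               # give every region its own id
h_slices = (slice(0, -M), slice(-M, -shift), slice(-shift, None))
w_slices = (slice(0, -M), slice(-M, -shift), slice(-shift, None))
cnt = 0
for hs in h_slices:
    for ws in w_slices:
        img_mask[:, hs, ws, :] = cnt
        cnt += 1

# cut the id map into windows and compare ids pairwise inside each window
mask_windows = img_mask.view(1, H // M, M, W // M, M, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, M * M)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)   # [num_windows, M*M, M*M]
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)           # cross-region pairs get -100
# attn_mask is added to the attention scores before softmax, so masked entries go to ~0
```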

After the multi-head self-attention is done, the stitched feature map must be shifted back so that its relative positions and semantic information are preserved. If it were not restored, then in the next pair of blocks the W-MSA would be computed on scrambled features, the SW-MSA would shift and stitch the already-shifted map again toward the bottom-right, and after several rounds the learned features would become more and more chaotic, with the feature map permanently scrambled.


3. Swin Transformer Blocks module

Swin Transformer Blocks come in two structures; the difference is that one computes window multi-head self-attention with the W-MSA structure and the other with the SW-MSA structure. The two are always used as a pair: first a block with W-MSA, then a block with SW-MSA. Therefore the number of stacked Swin Transformer Blocks is always even (the ×2 and ×6 under Swin Transformer Blocks in the overall model reflect this pairing).
[Figure: two successive Swin Transformer Blocks (W-MSA followed by SW-MSA)]
[Figure: the corresponding forward equations]

Walking through the forward pass by combining the figure with the formulas (a code sketch follows the list):

  1. An input sequence Z^(l-1) of shape [H/4n, W/4n, nC] (n = 1, 2, 4, 8 depending on the stage) is passed in.
  2. It goes through a LayerNorm layer, then the window self-attention W-MSA operation.
  3. The W-MSA output is added to Z^(l-1), giving Z'^(l).
  4. This goes through another LayerNorm layer, then the MLP. Note that inside the MLP block the channel depth is first expanded ×4 and then reduced back ÷4.
  5. The MLP output is added to Z'^(l), giving Z^(l).
  6. The sequence Z^(l) of shape [H/4n, W/4n, nC] is passed into the second block.
  7. It goes through a LayerNorm layer, then the shifted-window self-attention SW-MSA operation.
  8. The SW-MSA output is added to Z^(l), giving Z'^(l+1).
  9. This goes through another LayerNorm layer, then the MLP.
  10. The MLP output is added to Z'^(l+1), giving Z^(l+1).
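A sketch of this pair of blocks in PyTorch, where `w_msa` and `sw_msa` stand in for the (not shown) W-MSA and SW-MSA modules, so only the LayerNorm / residual / MLP wiring of the formulas is demonstrated:

```python
import torch.nn as nn

def mlp(dim, ratio=4):
    # channel depth is expanded x4 inside the MLP and reduced back
    return nn.Sequential(nn.Linear(dim, ratio * dim), nn.GELU(), nn.Linear(ratio * dim, dim))

class SwinBlockPair(nn.Module):
    """One W-MSA block followed by one SW-MSA block (the two are always used as a pair)."""
    def __init__(self, dim, w_msa, sw_msa):
        super().__init__()
        self.norm1, self.attn1, self.norm2, self.mlp1 = nn.LayerNorm(dim), w_msa, nn.LayerNorm(dim), mlp(dim)
        self.norm3, self.attn2, self.norm4, self.mlp2 = nn.LayerNorm(dim), sw_msa, nn.LayerNorm(dim), mlp(dim)

    def forward(self, z):                      # z = Z^(l-1)
        z = z + self.attn1(self.norm1(z))      # -> Z'^(l)
        z = z + self.mlp1(self.norm2(z))       # -> Z^(l)
        z = z + self.attn2(self.norm3(z))      # -> Z'^(l+1)
        z = z + self.mlp2(self.norm4(z))       # -> Z^(l+1)
        return z
```

For a quick shape check, `nn.Identity()` can be passed for both attention modules.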

4. Overall model understanding

[Figure: overall Swin Transformer architecture]

Forward process (a shape walkthrough follows the list):

  1. Input an image of size H×W×3.
  2. Apply patch partition, dividing the image into H/4 × W/4 patches, each flattened into a 48-dimensional vector (H/4 × W/4 × 48 = H × W × 3); no CLS token is used.
  3. Apply Linear Embedding, a linear transform of each patch's channel vector that maps the 48 dimensions to a preset value C (C is a hyperparameter that depends on the model variant, just as ResNet comes in 18/34/50/101/152 versions). The shape changes from [H/4, W/4, 48] to [H/4, W/4, C]. The spatial dimensions [H/4, W/4] are flattened into a sequence of length H/4 × W/4 = HW/16, and C becomes the vector dimension of each token.
  4. The sequence of HW/16 patches is quite long: for a 224×224×3 image the patch partition gives 56×56×48, so the sequence length is 3136 patches with C = 96, far longer than ViT's 196 patches. This is why window-based self-attention is introduced, each window typically covering a sequence of 7×7 = 49 patches.
  5. The [H/4, W/4, C] sequence goes into the Blocks of Stage 1 (×2), and the output is still [H/4, W/4, C].
  6. To build a hierarchical Transformer that extracts multi-scale information, the [H/4, W/4, C] output of Stage 1 is passed to a Patch Merging module, which performs a pooling-like downsampling. Patch Merging produces [H/8, W/8, 4C], and a 1×1 convolution then reduces 4C to 2C, imitating the channel doubling of a convolutional model.
  7. The [H/8, W/8, 2C] sequence goes into the Blocks of Stage 2 (×2), and the output is still [H/8, W/8, 2C].
  8. The [H/8, W/8, 2C] output of Stage 2 is passed to a Patch Merging module, which outputs [H/16, W/16, 8C]; a 1×1 convolution reduces 8C to 4C, again imitating channel doubling.
  9. The [H/16, W/16, 4C] sequence goes into the Blocks of Stage 3 (×6), and the output is still [H/16, W/16, 4C].
  10. The [H/16, W/16, 4C] output of Stage 3 is passed to a Patch Merging module, which outputs [H/32, W/32, 16C]; a 1×1 convolution reduces 16C to 8C.
  11. The [H/32, W/32, 8C] sequence goes into the Blocks of Stage 4 (×2), and the output is still [H/32, W/32, 8C].
  12. The above is the backbone. For image classification, to stay consistent with convolutional networks, the Swin Transformer paper does not use a CLS token like ViT; instead it appends a LayerNorm layer, a global pooling layer, and a fully connected layer to produce the final output. (These are not drawn in the figure, because Swin Transformer is intended not just for classification but also for detection and segmentation, so only the backbone is shown, without the classification or detection head.)
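The shape bookkeeping of the whole forward pass, for the 224×224×3 example with C = 96 (Swin-T), can be checked with a few lines:

```python
# after patch partition + linear embedding: [H/4, W/4, C]
H, W, C = 224, 224, 96
shape = (H // 4, W // 4, C)
print("Stage 1:", shape, "Blocks x2")              # (56, 56, 96)

for stage, blocks in zip((2, 3, 4), (2, 6, 2)):
    h, w, c = shape
    shape = (h // 2, w // 2, 2 * c)                # Patch Merging: halve H and W, double C
    print(f"Stage {stage}:", shape, f"Blocks x{blocks}")
# Stage 2: (28, 28, 192)  Stage 3: (14, 14, 384)  Stage 4: (7, 7, 768)
```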

Refer to other similar pictures:
[Figure: a similar diagram of the overall architecture from another source]

5. Relative Position Bias parameter

[Figure: ablation results]

In its ablation experiments, Swin Transformer shows that using SW-MSA is better than using only W-MSA, and that relative position bias (rel. pos.) outperforms both absolute position embedding (abs. pos.) and no position information (no pos.).


Attention(Q, K, V) = SoftMax(QK^T / √d + B) V
Looking back at the paper's formula, the self-attention has an extra +B term, and this B is the relative position bias.

The following explanation and diagrams borrow from another blogger's write-up.

1. Suppose the window size M×M is 2×2 patches. When computing self-attention within the window, the relative position indices are computed first. These indices are not the bias B itself; they are the elements used to build B.
[Figure: relative positions within a 2×2 window]

2. A distinction must be made between absolute position and relative position. Relative position is defined with respect to a reference: it is obtained by subtracting the other patch's coordinates from the reference patch's coordinates.

3. Each patch, taken in turn as the reference, yields its own set of relative position indices. Flattening each set and stacking them gives a matrix A of shape (M×M)×(M×M), i.e. (M×M)² entries.

4. Observe the pattern of relative position indices in matrix A. Take "to the right" as an example: the patch to the right of the red patch is blue, and with red as the reference, blue's relative position is [0,-1]; likewise, the patch to the right of the yellow patch is green, and its relative position is also [0,-1]. Looking carefully, you will find that the same positional relationship (above, below, left, right, and so on) always yields the same index.

5. A hash transform over the rows and columns converts the 2D relative position index into 1D to simplify the lookup. The hash formula used by the author is (x + M - 1) × (2M - 1) + (y + M - 1), where x is the row offset, y is the column offset, and M is the window size.

6. Reducing 2D to 1D could be done by simply adding or multiplying the two coordinates, but that produces collisions. For example, the patch to the right of red is [0,-1] and the patch below it is [-1,0]; they are clearly different in 2D, but both sum to -1 (and both multiply to 0), so different inputs give the same output. The hash transform above avoids these collisions.

[Figures: flattening the 2D relative position indices and applying the hash transform]
7. With the collisions removed, each relative position has a unique value, and there are (2M-1)×(2M-1) possible relative position indices in total. A table of (2M-1)×(2M-1) relative position biases is therefore created as learnable parameters (nn.Parameter), initialized randomly. The relative position index is then used to look up the corresponding bias from this table; that is the B in the formula, used when computing the multi-head self-attention.
[Figure: looking up the relative position bias table with the indices]
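A sketch of steps 1-7 for a window of M×M patches (the number of heads is assumed to be 3 purely for illustration):

```python
import torch
import torch.nn as nn

M, num_heads = 7, 3                                               # window size, heads (assumed)

coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))  # [2, M, M]
coords = coords.flatten(1)                                        # [2, M*M] absolute positions
rel = coords[:, :, None] - coords[:, None, :]                     # pairwise (dx, dy): [2, M*M, M*M]
rel = rel.permute(1, 2, 0).contiguous()                           # [M*M, M*M, 2]
rel[:, :, 0] += M - 1                                             # shift rows into [0, 2M-2]
rel[:, :, 1] += M - 1                                             # shift cols into [0, 2M-2]
rel[:, :, 0] *= 2 * M - 1                                         # hash: (x + M-1)*(2M-1) + (y + M-1)
index = rel.sum(-1)                                               # [M*M, M*M] relative position index

bias_table = nn.Parameter(torch.zeros((2 * M - 1) ** 2, num_heads))  # (2M-1)^2 learnable biases per head
B = bias_table[index.view(-1)].view(M * M, M * M, num_heads)         # the B added to QK^T / sqrt(d)
```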

6. Swin Transformer variants

[Table: configurations of the Swin Transformer variants]

  1. win. sz. 7x7 indicates the size of the window used
  2. Dim represents the channel depth of the feature map (or the vector length of the token)
  3. head indicates the number of heads in the multi-head attention module

To be continued...


Source: blog.csdn.net/qq_45848817/article/details/127105956