Swin Transformer
To address ViT's poor fit for downstream tasks such as detection and segmentation, Swin Transformer proposes a shifted-window scheme.
Features of Swin:
- Starting from small patches, merge adjacent patches layer by layer
- Compute attention within windows (Window Attention)
- Propose the Shifted Window operation to compute attention more efficiently
1. Paper reading notes
Swin Transformer uses shifted windows to build a hierarchical ViT. Like a CNN, the network is divided into several stages and extracts features hierarchically, and its computational complexity is linear in image size.
1.1 Summary:
The paper points out the problems in moving the Transformer from NLP to vision:
- Scale variation is large (pedestrians and cars in a street scene appear at many different sizes, which has no counterpart in NLP)
- At high resolution the token sequence becomes very long and the computation heavy
Previous solutions:
- Use downstream feature maps, rather than raw pixels, as the input to the Transformer
- Split the image into patches to reduce the resolution the Transformer has to handle
- Divide the image into small windows and compute self-attention inside each window
This paper proposes shifting the windows. Windowing reduces the amount of computation, and the shift lets adjacent windows interact, so consecutive layers have cross-window connections. The hierarchical structure is flexible and provides information at multiple scales, and because self-attention is computed inside small windows, the computational complexity grows linearly with image size.
1.2 Conclusion:
Swin computes self-attention inside small windows, whereas ViT computes it over the whole image. With a fixed window size, the cost of self-attention per window is fixed, so the total cost grows linearly with image size: if the image has x times as many pixels, there are x times as many windows and the cost is x times as large, not x squared.
This exploits the locality inductive bias of CNNs: different parts of the same object (or different objects with similar semantics) are very likely to appear close together, so attention within a small local window is usually sufficient, and global attention would waste computation.
In CNNs, multi-scale features come from pooling, which enlarges the receptive field so that the features after each pooling step capture objects of different sizes. This paper proposes patch merging, which combines adjacent patches into one larger patch, enlarging the receptive field and capturing multi-scale features. With feature maps at 4×, 8×, and 16× downsampling, detection can be done by feeding them to an FPN and segmentation by feeding them to a U-Net, so Swin Transformer can serve as a backbone network.
After the windows are partitioned, the shift operation lets them interact with each other.
Take a 224×224×3 image. It is first partitioned into 4×4 patches, so each side shrinks to 1/4 of the original, i.e. 56, and each patch holds 4×4×3 = 48 values, giving 56×56×48. Linear Embedding then projects each patch vector to a preset dimension, the hyperparameter C = 96, giving 56×56×96, which flattens to 3136×96. ViT with 16×16 patches has only 14×14 = 196 tokens, so a sequence of 3136 is very long. Swin therefore works on windows: each window holds only 7×7 = 49 patches, and self-attention is computed inside it. Treating the block as a black box for now, with no other constraints its input and output sizes are the same, so the output is still 56×56×96.
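The arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the function name `patch_embedding_shapes` and its parameter names are made up for illustration and do not come from the paper or any library.

```python
def patch_embedding_shapes(img_size=224, in_chans=3, patch_size=4,
                           embed_dim=96, window_size=7):
    """Trace tensor sizes through patch partition and linear embedding."""
    side = img_size // patch_size                    # patches per side: 224 / 4 = 56
    patch_dim = patch_size * patch_size * in_chans   # raw values per patch: 4*4*3 = 48
    seq_len = side * side                            # flattened sequence length
    num_windows = (side // window_size) ** 2         # non-overlapping windows
    tokens_per_window = window_size * window_size    # patches inside one window
    return {
        "after_partition": (side, side, patch_dim),
        "after_embedding": (side, side, embed_dim),
        "flattened": (seq_len, embed_dim),
        "num_windows": num_windows,
        "tokens_per_window": tokens_per_window,
    }


shapes = patch_embedding_shapes()
for name, value in shapes.items():
    print(f"{name}: {value}")
```

Attention over the full sequence would compare 3136 tokens with one another, while each window only compares its own 49.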
In Patch Merging, the four neighboring patches are concatenated, which takes the channels from C to 4C, and a linear layer (equivalent to a 1×1 convolution) then reduces them from 4C to 2C. The net effect is to halve the spatial size and double the number of channels, playing the same role as a pooling layer in a CNN.
2. Swin Transformer architecture analysis
As in ViT, the input image first goes through Patch Partition (splitting the image into patches) and Patch Embedding, and then passes through 4 stages, similar to the stages in ResNet. Each stage consists mainly of Swin Transformer Blocks, followed by a Patch Merging step that fuses neighboring patches.
The most critical operations are the Swin Transformer Block and Patch Merging.
The figure above shows the overall structure of the model; it also helps to follow the flow of data.
Step by step:
1. The image enters the network with 3 channels if it is a color image. After Patch Embedding, the number of channels becomes embed_dim.
2. After Patch Embedding, the patches are cut again into windows. The input is now a feature-level tensor, and Window Partition splits it into non-overlapping windows.
3. Without windows, attention would be computed between each patch and every other patch. With windows, attention is computed separately inside each window, which reduces the computation because a patch no longer attends to all the other patches. Attention preserves the shape of its input, so once every window has been processed and the windows are put back together, the tensor has the same size as before.
4. Patch Merging
In Swin Transformer, every four adjacent tokens are fused together, so the spatial size shrinks while embed_dim doubles.
5. Next Stage
After one stage is completed, the next stage takes the smaller merged tensor as input and repeats the steps above: partition into windows, reduce the size, and increase the dimension.
In some stages the block is repeated several times; the blocks themselves do not change the size, so their input and output dimensions are the same.
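The flow across stages can be traced with a small helper. This is a sketch under stated assumptions: it starts from the 56×56×96 tensor used in these notes and assumes patch merging is applied between consecutive stages (three times for four stages); the name `stage_shapes` is hypothetical.

```python
def stage_shapes(side=56, dim=96, num_stages=4):
    """Return (height, width, channels) of the feature map inside each stage."""
    shapes = []
    for _ in range(num_stages):
        shapes.append((side, side, dim))
        # Patch merging between stages: halve each side, double the channels
        side, dim = side // 2, dim * 2
    return shapes


for i, (h, w, c) in enumerate(stage_shapes(), start=1):
    print(f"stage {i}: {h}x{w} tokens, {c} channels")
```

Each merge divides the token count by 4 and multiplies the channels by 2, which is what gives the backbone its multi-scale feature maps.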
2.1 Swin Transformer Block
This section describes how the block is constructed.
A pair of blocks uses W-MSA (Window Multi-head Self-Attention) and SW-MSA (Shifted Window Multi-head Self-Attention). This article leaves the shifted-window part aside for now and looks only at the left-hand block. The input passes through a LayerNorm (LN) layer, then W-MSA, followed by a residual connection; it then passes through LN and an MLP, followed by another residual connection. This is the same as a standard Transformer block, except that the attention is replaced by W-MSA.
W-MSA
After the tensor is partitioned into windows, each window is processed on its own: the 16 tokens of window 1 attend only to one another, then the 16 tokens of window 2 attend only to one another, and so on, each window separately.
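One way to see what "each window attends only to itself" means: per-window attention gives the same result as global attention in which every pair of tokens from different windows is masked out. The sketch below checks this on a small 8×8 map with 4×4 windows (16 tokens each). To stay short it uses a single head and sets Q = K = V = x, which is a simplification and not how the real layer works.

```python
import torch

torch.manual_seed(0)
H = W = 8    # feature map size
ws = 4       # window size, so each window holds 16 tokens
C = 6        # channels
x = torch.randn(H, W, C)


def attend(t):
    """Plain scaled dot-product attention with q = k = v = t."""
    scores = t @ t.transpose(-1, -2) / C ** 0.5
    return scores.softmax(-1) @ t


# 1) Window attention: cut into windows, attend inside each, put windows back
windows = x.reshape(H // ws, ws, W // ws, ws, C).permute(0, 2, 1, 3, 4)
windows = windows.reshape(-1, ws * ws, C)              # [num_windows, 16, C]
local_out = attend(windows)
local_out = local_out.reshape(H // ws, W // ws, ws, ws, C).permute(0, 2, 1, 3, 4)
local_out = local_out.reshape(H * W, C)

# 2) Global attention where tokens from different windows are masked out
rows = torch.arange(H).repeat_interleave(W)            # row of every flattened token
cols = torch.arange(W).repeat(H)                       # column of every flattened token
win_id = (rows // ws) * (W // ws) + cols // ws         # window index of every token
same_window = win_id[:, None] == win_id[None, :]       # [H*W, H*W] boolean mask
tokens = x.reshape(H * W, C)
scores = tokens @ tokens.T / C ** 0.5
scores = scores.masked_fill(~same_window, float("-inf"))
masked_out = scores.softmax(-1) @ tokens

print(torch.allclose(local_out, masked_out, atol=1e-5))
```

The windowed version never builds the full 64×64 score matrix, which is where the savings come from.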
The paper states that W-MSA is cheaper than MSA. For an h×w feature map with C channels and window size M, the costs given in the paper are:
Ω(MSA) = 4hwC² + 2(hw)²C
Ω(W-MSA) = 4hwC² + 2M²hwC
The two formulas differ only in the second term. In MSA it grows quadratically with the number of patches h·w, while in W-MSA it is linear in h·w for a fixed M. The larger the image, the more W-MSA saves.
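To put numbers on the comparison, the sketch below evaluates the two cost formulas from the Swin paper, Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC, for a few feature map sizes. The function names are made up for illustration, and C = 96 and M = 7 are the values used earlier in these notes.

```python
def msa_flops(h, w, C):
    """Cost of global multi-head self-attention on an h x w feature map."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Cost of window attention with M x M windows on the same feature map."""
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C


for side in (56, 112, 224):
    full = msa_flops(side, side, 96)
    windowed = wmsa_flops(side, side, 96, 7)
    print(f"{side}x{side}: MSA={full:,}  W-MSA={windowed:,}  ratio={full / windowed:.1f}")
```

Doubling the side length multiplies the number of tokens by 4. The W-MSA cost then grows by exactly 4, while the quadratic term of MSA grows by 16.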
2.2 Patch Merging
Split each 2×2 neighborhood into its four positions (the four colors in the figure), gather the tokens of each position into its own sub-map, and concatenate the four sub-maps along the channel dimension. After this merge the channel dimension is 4 times the original and the number of tokens is 1/4 of the original; a linear projection then maps the channels from 4 times down to 2 times.
Finally, reshape the result back into a feature map: the height and width are 1/2 of the original, and the channel dimension is 2 times the original.
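The grouping is easiest to see on a tiny tensor. The sketch below runs the merge on a 4×4 map with a single channel whose values are the flat token indices, so the output shows which tokens end up together. It is an illustration only; the variable names are arbitrary and the linear layer is randomly initialized.

```python
import torch

# A 4x4 feature map with one channel; each token's value is its flat index
x = torch.arange(16.0).reshape(1, 4, 4, 1)       # [B, H, W, C]

# Pick the four positions of every 2x2 neighborhood
x0 = x[:, 0::2, 0::2, :]   # top-left
x1 = x[:, 1::2, 0::2, :]   # bottom-left
x2 = x[:, 0::2, 1::2, :]   # top-right
x3 = x[:, 1::2, 1::2, :]   # bottom-right

# Concatenate on the channel axis: 4x the channels, 1/4 of the tokens
merged = torch.cat([x0, x1, x2, x3], dim=-1)     # [1, 2, 2, 4]

# Linear projection from 4C to 2C (here C = 1)
reduction = torch.nn.Linear(4, 2)
out = reduction(merged)                          # [1, 2, 2, 2]

print(merged[0, 0, 0])    # the four tokens of the top-left 2x2 block
print(merged.shape, out.shape)
```

The first merged token holds the values 0, 4, 1, 5, i.e. the top-left, bottom-left, top-right and bottom-right tokens of the first 2×2 block.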
3. Code implementation
The implementation involves W-MSA, Patch Merging, and Window Partition.
Window Partition cuts the tensor into windows and sends them to attention, where each window gets its own Q, K, and V. Suppose a batch has 3 samples of the same size: every window (the red frame in the figure) is cut out, and attention is computed for each window individually.
All the small windows of a batch can be stacked together. The windows are independent of one another: however they are arranged, each one attends only within itself.
Within a window, each cell attends to all the other cells of the same window; this is window attention. The 16 tokens of each small window are then pulled out and flattened into a sequence, as in the code below.
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=4, embed_dim=96):
        super().__init__()
        # A strided convolution splits the image into patches and embeds them in one step
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)  # [n, embed_dim, h', w']
        x = x.flatten(2)         # [n, embed_dim, h'*w']
        x = x.permute(0, 2, 1)   # [n, h'*w', embed_dim]
        x = self.norm(x)
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        h, w = self.resolution
        b, _, c = x.shape  # _ is unused; it is num_patches, i.e. h*w
        x = x.reshape([b, h, w, c])
        # The four positions of every 2x2 neighborhood
        x0 = x[:, 0::2, 0::2, :]  # top-left
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # [b, h/2, w/2, 4c]
        x = x.reshape([b, -1, 4 * c])
        x = self.norm(x)
        x = self.reduction(x)  # [b, h/2 * w/2, 2c]
        return x


class Mlp(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


def windows_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C])
    x = x.permute([0, 1, 3, 2, 4, 5])
    # [B, H//ws, W//ws, ws, ws, C]
    x = x.reshape([-1, window_size, window_size, C])
    # [B * num_windows, ws, ws, C]
    return x


def windows_reverse(windows, window_size, H, W):
    # Recover the batch size from the number of windows per image
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])
    x = x.permute([0, 1, 3, 2, 4, 5])
    # [B, H//ws, ws, W//ws, ws, C]
    x = x.reshape([B, H, W, -1])
    return x
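A quick way to gain confidence in the two helpers is a round-trip check: partition, reverse, and compare with the input. The sketch below uses compact standalone versions named `partition` and `reverse` (hypothetical names, written to follow the same reshape/permute logic) so that it can be run on its own, with a batch of 3 samples as in the description above.

```python
import torch


def partition(x, ws):
    """[B, H, W, C] -> [B * num_windows, ws, ws, C]"""
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)


def reverse(windows, ws, H, W):
    """[B * num_windows, ws, ws, C] -> [B, H, W, C]"""
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)


x = torch.randn(3, 56, 56, 96)
windows = partition(x, 7)
restored = reverse(windows, 7, 56, 56)

print(windows.shape)                            # 3 samples * 64 windows each
print(torch.equal(x, restored))                 # the round trip loses nothing
print(torch.equal(windows[0], x[0, :7, :7]))    # first window = top-left 7x7 crop
```

Because the windows of all samples are stacked along the first axis, the attention layer can treat them as one large batch.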
Next, WindowAttention is defined, and SwinBlock combines it with the pieces above:
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.dim_head = dim // num_heads
        self.num_heads = num_heads
        self.scale = self.dim_head ** -0.5
        self.softmax = nn.Softmax(-1)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def transpose_multi_head(self, x):
        # [B, num_patches, dim] -> [B, num_patches, num_heads, dim_head]
        new_shape = x.shape[:-1] + (self.num_heads, self.dim_head)
        x = x.reshape(new_shape)
        x = x.permute(0, 2, 1, 3)  # [B, num_heads, num_patches, dim_head]
        return x

    def forward(self, x):
        # x: [B, num_patches, embed_dim]; here B counts windows, num_patches = ws * ws
        B, N, C = x.shape
        qkv = self.qkv(x).chunk(3, -1)
        q, k, v = map(self.transpose_multi_head, qkv)
        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-1, -2))  # [B, num_heads, num_patches, num_patches]
        attn = self.softmax(attn)
        out = torch.matmul(attn, v)  # [B, num_heads, num_patches, dim_head]
        out = out.permute([0, 2, 1, 3])
        # [B, num_patches, num_heads, dim_head]; num_heads * dim_head = embed_dim
        out = out.reshape([B, N, C])
        out = self.proj(out)
        return out


class SwinBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size):
        super().__init__()
        self.dim = dim
        self.resolution = input_resolution
        self.window_size = window_size
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = Mlp(dim)

    def forward(self, x):
        H, W = self.resolution
        B, N, C = x.shape
        h = x  # keep the input for the residual connection
        x = self.attn_norm(x)  # attention must see the normalized tensor
        # cut into windows
        x = x.reshape([B, H, W, C])
        x_windows = windows_partition(x, self.window_size)
        # [B * num_windows, ws, ws, C]
        x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C])
        attn_windows = self.attn(x_windows)
        # after attention, restore the original layout
        attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])
        x = windows_reverse(attn_windows, self.window_size, H, W)
        # [B, H, W, C]
        x = x.reshape([B, H * W, C])
        x = h + x  # residual connection around attention
        h = x  # the MLP branch is added to this sum, not to the raw attention output
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x

Finally, a main function calls everything:

def main():
    t = torch.randn([4, 3, 224, 224])
    patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
    swinBlock = SwinBlock(dim=96, input_resolution=[56, 56], num_heads=4, window_size=7)
    patch_merging = PatchMerging(input_resolution=[56, 56], dim=96)
    out = patch_embedding(t)  # [4, 3136, 96], i.e. 56*56 tokens
    print('patch_embedding out shape= ', out.shape)
    out = swinBlock(out)
    print('swinBlock out shape= ', out.shape)
    out = patch_merging(out)
    print('patch_merging out shape= ', out.shape)


if __name__ == '__main__':
    main()
First we input a batch of data with shape [4, 3, 224, 224], so batch_size is 4. patch_embedding takes patches of size 4, so the 224×224 image becomes a 56×56 grid of tokens with 96 channels, flattened to [4, 3136, 96] (3136 = 56×56) to make the attention step convenient.
In swinBlock, windows_partition and WindowAttention are applied, and the shape does not change: it is still [4, 3136, 96].
Finally patch_merging, which is similar to pooling, merges each group of 4 adjacent tokens and doubles the channel dimension: 96 becomes 192, and the token count becomes 784 = 28×28, i.e. each side of 56×56 is halved. The output is [4, 784, 192].
So the windowing around WindowAttention is mainly a reshape into windows, followed by a reshape back afterwards.