Since its release, Swin Transformer has attracted widespread attention. I have used the model in my own experiments and found that it does give better results. Here I walk through the network structure in detail, following both the paper and the code.
Network structure in the paper
The Swin Transformer block structure shown in the paper is a cascade of a W-MSA block followed by an SW-MSA block:
Details of the network structure
Section 3 of the paper describes the overall network in detail. Below, I first excerpt the relevant parts of the paper and then expand on them with the code.
1 Input image preprocessing. The part of the overall architecture diagram that corresponds to this step is shown in the screenshot below.
The corresponding code is as follows:
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
That is, the code applies a 4 x 4 convolution with stride 4 at the input. Assuming a 224 x 224 input, the result is a 56 x 56 grid of patch tokens, each of dimension embed_dim (96 by default). No nonlinear layer is used, so this is exactly what the paper describes: split the image into 4 x 4 patches, then apply a linear projection to each patch to obtain its embedding vector.
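To make the point that a convolution with kernel size equal to its stride is just a per-patch linear projection more concrete, here is a minimal sketch in plain Python (no PyTorch needed). The helper names patch_grid and embed_patches are made up for this illustration, and the toy image has a single channel and a single output feature; the real layer does the same thing with 3 input channels and 96 output channels.

```python
def patch_grid(img_size=224, patch_size=4):
    """Patches per side and total number of patch tokens for a square image."""
    side = img_size // patch_size
    return side, side * side


def embed_patches(image, weight, patch_size):
    """Kernel == stride 'convolution' on a single-channel image,
    written as one dot product per non-overlapping patch."""
    height, width = len(image), len(image[0])
    tokens = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten the patch in row-major order, like a conv kernel would see it.
            patch = [image[top + i][left + j]
                     for i in range(patch_size)
                     for j in range(patch_size)]
            tokens.append(sum(p * w for p, w in zip(patch, weight)))
    return tokens


# 224 x 224 input with 4 x 4 patches -> 56 x 56 = 3136 tokens
print(patch_grid(224, 4))

# Toy 4 x 4 image, 2 x 2 patches, all-ones weight: each token is the patch sum.
toy = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12],
       [13, 14, 15, 16]]
print(embed_patches(toy, weight=[1, 1, 1, 1], patch_size=2))  # [14, 22, 46, 54]
```

Because the patches do not overlap, every token depends on exactly one patch, which is why the paper can describe this step as a linear embedding.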
2 SwinTransformerBlock
The Swin Transformer module in the paper's figure is labeled clearly: it shows two Swin Transformer blocks connected in sequence. The difference between them is that the first uses the W-MSA module and the second uses the SW-MSA module. If you do not read the paper or the code carefully, you may think that a single Swin block contains both modules, but it does not: these are two cascaded Swin blocks, as the later text of the paper also explains.
Next, let's take a look at the code. This is the part of BasicLayer that builds the blocks; you can see that W-MSA and SW-MSA alternate, selected through the shift_size argument:
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
fused_window_process=fused_window_process)
for i in range(depth)])
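The alternation rule in the list comprehension above can be checked in isolation. The sketch below only reproduces the shift_size expression; the function names shift_schedule and block_kinds are hypothetical and exist just for this illustration.

```python
def shift_schedule(depth, window_size):
    """shift_size passed to each block of a stage: 0 for even i, window_size // 2 for odd i."""
    return [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]


def block_kinds(depth, window_size):
    """Label each block by the attention variant its shift_size selects."""
    return ["W-MSA" if shift == 0 else "SW-MSA"
            for shift in shift_schedule(depth, window_size)]


print(shift_schedule(6, 7))  # [0, 3, 0, 3, 0, 3]
print(block_kinds(2, 7))     # ['W-MSA', 'SW-MSA']
```

So a stage of depth 2 is exactly the W-MSA / SW-MSA pair drawn in the paper's figure, and deeper stages repeat that pair.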
The SwinTransformerBlock class (one of the key parts of the paper):
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
fused_window_process=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
# this effectively splits img_mask into 9 regions
img_mask[:, h, w, :] = cnt
cnt += 1
#num_windows*1, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
#(num_windows,window_size*window_size)
# ->(num_windows,1,window_size*window_size)-(num_windows,window_size*window_size,1)
#->(num_windows,window_size*window_size,window_size*window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
self.fused_window_process = fused_window_process
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
If you start with the forward function of this class, you can roughly work out the computation flow of Swin Transformer, which matches the description in the paper. What we need to pay attention to is how W-MSA and SW-MSA are carried out in this part.
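As a reading aid for the forward function, the shapes at each step can be traced with simple arithmetic. The sketch below is an assumption-laden illustration, not part of the Swin code: block_shapes is a hypothetical helper, and it assumes H and W are divisible by window_size.

```python
def block_shapes(B, H, W, C, window_size):
    """Tensor shapes at the main steps of SwinTransformerBlock.forward."""
    assert H % window_size == 0 and W % window_size == 0, "resolution must be divisible by window_size"
    num_windows = (H // window_size) * (W // window_size)
    tokens_per_window = window_size * window_size
    return {
        "input": (B, H * W, C),                                    # x as received
        "as_image": (B, H, W, C),                                  # after x.view(B, H, W, C)
        "windows": (num_windows * B, tokens_per_window, C),        # after window_partition + view
        "attention_per_head": (num_windows * B, tokens_per_window, tokens_per_window),
        "output": (B, H * W, C),                                   # after window_reverse + view
    }


# First stage with the default settings: 56 x 56 tokens, 96 channels, 7 x 7 windows.
for name, shape in block_shapes(B=1, H=56, W=56, C=96, window_size=7).items():
    print(name, shape)
```

With these numbers there are 64 windows of 49 tokens each, so attention is computed over 49 x 49 maps instead of one 3136 x 3136 map.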
First, look at the WindowAttention class used by this block. It is a multi-head attention module (somewhat similar to group convolution, in that each head works on its own slice of the channels). The code also contains the relative position bias mentioned in the paper.
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
# 169 * num_heads
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# subtract each of the 49 positions from every one of the 49 positions to get the relative coordinates
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# values range from 0 to 168 (168 = (6+6)*13 + (6+6)); entries on the main diagonal are 84
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#(B,num_heads,N,C // self.num_heads)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))#(B,num_heads,N,N) N = 49
# look up the learned relative position bias for every pair of positions: Wh*Ww, Wh*Ww, nH
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)#(B,num_heads,N,N)
if mask is not None:
nW = mask.shape[0]#n_windows
#mask:num_windows, Wh*Ww, Wh*Ww-> 1, num_windows, 1,Wh*Ww, Wh*Ww
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
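The group-convolution analogy comes from how the channels are split across heads. The following sketch shows that split and the default scale factor on plain lists; head_config and split_heads are hypothetical helper names, and dim=96 with num_heads=3 is just an example configuration.

```python
def head_config(dim, num_heads):
    """Per-head channel count and the default scale head_dim ** -0.5."""
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = dim // num_heads
    return head_dim, head_dim ** -0.5


def split_heads(token, num_heads):
    """Split one token's channel vector into contiguous per-head slices,
    mirroring the reshape to (..., num_heads, C // num_heads)."""
    head_dim = len(token) // num_heads
    return [token[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]


head_dim, scale = head_config(dim=96, num_heads=3)
print(head_dim)                          # 32
print(split_heads(list(range(6)), 3))    # [[0, 1], [2, 3], [4, 5]]
```

Each head computes its own N x N attention map from its slice only, and the per-head outputs are concatenated again before the final projection.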
To understand how the relative position encoding is implemented, read this part. relative_position_bias_table stores the learnable bias values and has shape [(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads]. Note that the values are initialized by sampling from a truncated normal distribution (std 0.02), so they are random: there is no pattern of large values in the middle and small values at the edges, as the shape of a normal curve might suggest, and because the tensor is learnable the values change during training. relative_position_index is the relative position index for the given window size. Its shape is (window_size[0] * window_size[1]) x (window_size[0] * window_size[1]), which matches the size of the window attention map, and each entry lies in the range [0, (2 * window_size[0] - 1) * (2 * window_size[1] - 1) - 1].
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# subtract each of the 49 positions from every one of the 49 positions to get the relative coordinates
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# values range from 0 to 168 (168 = (6+6)*13 + (6+6)); entries on the main diagonal are 84
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
trunc_normal_(self.relative_position_bias_table, std=.02)
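The index computation is easier to follow on a small window. The sketch below re-derives relative_position_index with nested lists instead of tensors; the function name is borrowed from the buffer name, but the implementation is my own plain-Python rewrite and not the original code.

```python
def relative_position_index(wh, ww):
    """Index into the bias table for every (query, key) pair inside a wh x ww window."""
    # Token coordinates in row-major order, matching the flatten in the original code.
    coords = [(h, w) for h in range(wh) for w in range(ww)]
    index = []
    for qh, qw in coords:
        row = []
        for kh, kw in coords:
            dh = qh - kh + (wh - 1)   # shift the row offset to start from 0
            dw = qw - kw + (ww - 1)   # shift the column offset to start from 0
            row.append(dh * (2 * ww - 1) + dw)
        index.append(row)
    return index


for row in relative_position_index(2, 2):
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

For a 2 x 2 window the table has 3 * 3 = 9 rows, and pairs with the same relative offset share an index (for example, every diagonal entry is 4). For the 7 x 7 window the same function gives values from 0 to 168, with 84 on the diagonal.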
The window attention code also handles a mask. The mask is generated to work together with shift_size: after the cyclic shift, a single window can contain sub-windows that come from different parts of the feature map, and the mask restricts attention to tokens within the same sub-window. The paper describes this in its discussion of shifted windows.
if mask is not None:
nW = mask.shape[0]#n_windows
#mask:num_windows, Wh*Ww, Wh*Ww-> 1, num_windows, 1,Wh*Ww, Wh*Ww
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
Next, let's look at how this mask is built. The code shows that the mask is derived from a template, img_mask. Following the relationship used in the paper's code, assume H=W=14, window_size=7 and shift_size=3; the feature map is then divided into 2*2=4 windows.
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
# this effectively splits img_mask into 9 regions
img_mask[:, h, w, :] = cnt
cnt += 1
#num_windows*1, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
#(num_windows,window_size*window_size)
# ->(num_windows,1,window_size*window_size)-(num_windows,window_size*window_size,1)
#->(num_windows,window_size*window_size,window_size*window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
The img_mask at this point looks like this:
After the img_mask above passes through window_partition, it becomes:
This matches the description in the paper. Why build the mask template this way? Read on.
After partitioning, we end up with 4 attention masks, one per window. They correspond to the windows obtained when the shifted input tensor (B, H, W, C) is partitioned, and they restrict the self-attention inside each window to tokens that belong to the same region of the original feature map. Applying them in every window gives the self-attention computation for the whole input tensor. The final masks are visualized below as binary images.
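The mask construction can be reproduced in plain Python for the H=W=14, window_size=7, shift_size=3 case. This is only a sketch: nested lists stand in for tensors, and the helper names (region_ids, build_img_mask, partition, attention_mask) are hypothetical, not part of the Swin code.

```python
def region_ids(size, window_size, shift_size):
    """Label each index 0, 1 or 2, matching the three slices
    [0, -window_size), [-window_size, -shift_size), [-shift_size, end)."""
    ids = []
    for p in range(size):
        if p < size - window_size:
            ids.append(0)
        elif p < size - shift_size:
            ids.append(1)
        else:
            ids.append(2)
    return ids


def build_img_mask(H, W, window_size, shift_size):
    """Region label (0..8) for every position, in the same order as the cnt counter."""
    h_ids = region_ids(H, window_size, shift_size)
    w_ids = region_ids(W, window_size, shift_size)
    return [[3 * h_ids[r] + w_ids[c] for c in range(W)] for r in range(H)]


def partition(img, window_size):
    """Split an H x W grid into flattened, non-overlapping windows (row-major)."""
    H, W = len(img), len(img[0])
    return [[img[top + i][left + j]
             for i in range(window_size) for j in range(window_size)]
            for top in range(0, H, window_size)
            for left in range(0, W, window_size)]


def attention_mask(windows):
    """0.0 where two tokens share a region label, -100.0 otherwise."""
    return [[[0.0 if a == b else -100.0 for b in win] for a in win]
            for win in windows]


windows = partition(build_img_mask(14, 14, 7, 3), 7)
masks = attention_mask(windows)
print([sorted(set(win)) for win in windows])
# [[0], [1, 2], [3, 6], [4, 5, 7, 8]]
```

The first window contains a single region, so its mask is all zeros and attention is unrestricted. The last window mixes four regions, so most token pairs in it are blocked.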
The mask is what makes SW-MSA work, and the following is the key code: the mask is added to the attention scores after q @ k has been computed and the relative position bias has been added. Where two tokens belong to the same region, 0 is added and attention is computed as usual. Where they do not, a large negative value (-100) is added, so the weight for that pair becomes almost zero after softmax. When the result is then multiplied by v, those positions contribute essentially nothing, which masks out unrelated positions while still computing attention among the positions that belong together.
if mask is not None:
nW = mask.shape[0]#n_windows
#mask:num_windows, Wh*Ww, Wh*Ww-> 1, num_windows, 1,Wh*Ww, Wh*Ww
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
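The effect of adding -100 before the softmax can be seen on a single row of scores. The numbers below are invented purely for illustration, and softmax is a small stand-alone helper rather than the nn.Softmax used in the model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# One query token attending to four keys; the last two keys are in another region.
scores = [1.0, 2.0, 0.5, 1.5]
mask_row = [0.0, 0.0, -100.0, -100.0]

weights = softmax([s + b for s, b in zip(scores, mask_row)])
print(weights)
```

The two masked keys end up with weights that are vanishingly small (far below 1e-40), while the two unmasked keys share essentially all of the probability mass. Using a finite -100 instead of negative infinity gives the same practical result while keeping every value finite.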
The above is essentially the whole window attention computation. So far we have mainly talked about W-MSA. For SW-MSA, note that the input feature map is cyclically shifted before attention is computed. Also note that in the code shift_size is window_size // 2 in every stage (unless the input resolution is no larger than the window, in which case it is set to 0). Since window_size is a fixed value, shift_size is fixed as well. The paper's illustration and the code are as follows:
The following code produces the effect shown in the figure above.
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size)
After SW-MSA is computed, the windows are merged and the cyclic shift is reversed:
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
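The cyclic shift and its reversal are easy to verify on a tiny grid. The helper roll2d below is a plain-Python stand-in that I wrote to mimic what torch.roll does on two dimensions; it is an illustration, not the library function.

```python
def roll2d(grid, shift_h, shift_w):
    """Cyclically shift a 2D list: the element at (r, c) moves to
    ((r + shift_h) % H, (c + shift_w) % W)."""
    H, W = len(grid), len(grid[0])
    return [[grid[(r - shift_h) % H][(c - shift_w) % W] for c in range(W)]
            for r in range(H)]


grid = [[4 * r + c for c in range(4)] for r in range(4)]

shifted = roll2d(grid, -1, -1)     # same direction as shifts=(-shift_size, -shift_size)
restored = roll2d(shifted, 1, 1)   # the reverse shift after attention

for row in shifted:
    print(row)
# [5, 6, 7, 4]
# [9, 10, 11, 8]
# [13, 14, 15, 12]
# [1, 2, 3, 0]
```

The first row and first column wrap around to the bottom and right, which is how tokens from opposite edges end up in the same window and why the attention mask is needed. Rolling back by the same amount restores the original layout exactly.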
The above is what the SwinTransformerBlock module does. Next, let's look at the pooling mechanism in the paper: PatchMerging.
Pooling mechanism PatchMerging
This module downsamples the feature map, but it differs from the fairly direct pooling strategies used in convolutional neural networks. Here the feature map is sampled at every other row and every other column, which turns the (H, W) map into four (H/2, W/2) maps that are concatenated along the channel dimension (4C). A fully connected layer then reduces the channel dimension (4C -> 2C), so an input of (H, W, C) becomes (H/2, W/2, 2C), which is a 2x downsampling.
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
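The interleaved sampling in forward can be followed on a 4 x 4 grid with one channel. The sketch below covers only the gather-and-concatenate step (no LayerNorm, no linear reduction); patch_merge and stage_shapes are hypothetical helpers, and the four-stage, 96-channel setup in stage_shapes is an assumption based on the default embed_dim shown earlier.

```python
def patch_merge(x):
    """x is an H x W grid whose cells are lists of C channel values.
    Returns an H/2 x W/2 grid whose cells hold 4*C values, in the order x0, x1, x2, x3."""
    H, W = len(x), len(x[0])
    assert H % 2 == 0 and W % 2 == 0, "H and W must be even"
    merged = []
    for i in range(0, H, 2):
        row = []
        for j in range(0, W, 2):
            # x0: even row, even col; x1: odd row, even col;
            # x2: even row, odd col;  x3: odd row, odd col
            row.append(x[i][j] + x[i + 1][j] + x[i][j + 1] + x[i + 1][j + 1])
        merged.append(row)
    return merged


def stage_shapes(resolution=56, dim=96, num_stages=4):
    """(H, W, C) entering each stage if every merge halves H and W and doubles C."""
    shapes = []
    for _ in range(num_stages):
        shapes.append((resolution, resolution, dim))
        resolution //= 2
        dim *= 2
    return shapes


grid = [[[4 * r + c] for c in range(4)] for r in range(4)]
print(patch_merge(grid)[0][0])   # [0, 4, 1, 5]
print(stage_shapes())            # [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

Unlike max or average pooling, no value is discarded at this step: each 2 x 2 neighborhood is kept whole and stacked into the channel dimension, and the information is only compressed afterwards by the 4C -> 2C linear layer.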
Note that the figure in the paper groups Patch Merging together with the Swin blocks of stage 2. In the code, however, Patch Merging is the last step of stage 1. The computation is the same either way, so this does not contradict the figure in the paper.
Summary
The Swin Transformer code is well worth a careful read, since it gives a much more complete understanding of the techniques in the paper. Writing this up took some effort; if you have any questions, please leave a comment.