Detailed explanation of the Swin Transformer's overall network structure and code

Since its release, the Swin Transformer has attracted widespread attention. I have also used this model in my own experiments and found that it does deliver better results. Here, we analyze the network structure in detail, combining the paper with the code.

Network structure in the paper

[Figure: overall Swin Transformer architecture]
The Swin Transformer Block itself is a cascade of two units, one using standard W-MSA and the other SW-MSA:
[Figure: two successive Swin Transformer Blocks]

Details of the network structure

The details of the overall network are described very thoroughly in Section 3 of the paper. Below, the relevant passages from the paper are excerpted first and then expanded on in combination with the code.

1. The input image preprocessing part of the network. The corresponding portion of the overall architecture diagram, together with the relevant excerpt from the paper, is shown below.

[Figure: patch partition and linear embedding stage of the architecture, with the corresponding excerpt from the paper]
The corresponding code is as follows:

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

That is, at the input the code applies a 4 x 4 convolution with stride 4. Assuming the original input is 224 x 224, the feature map becomes 56 x 56 after this step. No nonlinearity is used in this process, so it matches what the paper describes: the image is split into 4 x 4 patches, and a linear transformation then produces the embedded feature vector for each patch.
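
To make the 224 -> 56 shape change concrete, here is a minimal standalone sketch of the projection step (random input data; the default Swin-T settings img_size=224, patch_size=4, in_chans=3, embed_dim=96 are assumed):

import torch
import torch.nn as nn

# Minimal sketch of the patch-embedding projection (assumed Swin-T defaults).
proj = nn.Conv2d(3, 96, kernel_size=4, stride=4)

x = torch.randn(1, 3, 224, 224)           # B, C, H, W
feat = proj(x)                            # (1, 96, 56, 56) because 224 / 4 = 56
tokens = feat.flatten(2).transpose(1, 2)  # (1, 56*56, 96) = (1, 3136, 96)
print(feat.shape, tokens.shape)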

2. SwinTransformerBlock

[Figure: two successive Swin Transformer Blocks (W-MSA followed by SW-MSA), with the corresponding description from the paper]
The Swin Transformer module shown in the paper is labeled very clearly in the figure: it is a connection of two Swin Transformer Blocks, the first built around a W-MSA module and the second around an SW-MSA module. The follow-up text of the paper explains this as well. On a quick reading of the figure one might think both parts are contained in a single block, but that is not the case; these are two cascaded blocks.
Next, let's look at the code. The blocks are built in the BasicLayer class; you can see that W-MSA and SW-MSA alternate, selected through the shift_size argument:

# build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 fused_window_process=fused_window_process)
            for i in range(depth)])
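
As a quick illustration (window_size = 7 and depth = 4 are assumed here purely for demonstration), the resulting shift sizes alternate between 0 (W-MSA) and window_size // 2 = 3 (SW-MSA):

window_size, depth = 7, 4  # assumed values for illustration
shift_sizes = [0 if i % 2 == 0 else window_size // 2 for i in range(depth)]
print(shift_sizes)  # [0, 3, 0, 3] -> W-MSA, SW-MSA, W-MSA, SW-MSA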

The SwinTransformerBlock class (one of the key parts of the paper):

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one fused kernel for window shift & window partition for acceleration, and similarly for the reverse part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    # equivalent to splitting img_mask into 9 regions
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            #num_windows*1, window_size, window_size, 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            #(num_windows,window_size*window_size)
            # ->(num_windows,1,window_size*window_size)-(num_windows,window_size*window_size,1)
            #->(num_windows,window_size*window_size,window_size*window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

To understand this class, look first at its forward function: from it you can roughly trace the computation of a Swin Transformer block, and it is consistent with the description in the paper.
[Figure: the block computation as described in the paper]
What deserves the most attention is how the W-MSA and SW-MSA steps are carried out.
First look at the WindowAttention class used here. It is a multi-head attention module (somewhat similar in spirit to grouped convolution), and it also contains the relative position bias mentioned in the paper.
[Figure: excerpt from the paper on relative position bias]

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        # 169 * num_heads
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # each of the 49 positions is differenced against all 49 positions, giving relative coordinates
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # values range from 0 to 168 = (6+6)*13 + (6+6); the main diagonal is 84
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v each: (B_, num_heads, N, C // num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))#(B,num_heads,N,N) N = 49
        # look up the relative position bias for every pair of positions: Wh*Ww, Wh*Ww, nH
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)#(B,num_heads,N,N)

        if mask is not None:
            nW = mask.shape[0]#n_windows
            #mask:num_windows, Wh*Ww, Wh*Ww-> 1, num_windows, 1,Wh*Ww, Wh*Ww
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
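
As a quick shape check, assuming WindowAttention is defined as above and that trunc_normal_ is importable in the same module (for example from timm.models.layers), a 7x7 window maps an input of shape (num_windows*B, 49, C) to an output of the same shape:

import torch
# assumes WindowAttention is defined as above and trunc_normal_ is available
attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)

x = torch.randn(8 * 2, 49, 96)    # (num_windows*B, N, C): nW=8, B=2, N=7*7
out = attn(x, mask=None)          # W-MSA path: no mask
print(out.shape)                  # torch.Size([16, 49, 96])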

If you understand how relative position encoding is implemented, this section is easy to follow. relative_position_bias_table stores the learnable position-bias values, a tensor of shape [(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads]. Note that these values are initialized by truncated-normal random sampling, so their initial distribution is random; there is no built-in pattern such as large values in the middle and small values at the edges, because the tensor is learned during training. relative_position_index is the table of relative-position indices for a given window size; its shape is (window_size[0] * window_size[1]) x (window_size[0] * window_size[1]), matching the size of the window self-attention module, and each index takes a value in [0, (2 * window_size[0] - 1) * (2 * window_size[1] - 1) - 1].

 self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # each of the 49 positions is differenced against all 49 positions, giving relative coordinates
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # values range from 0 to 168 = (6+6)*13 + (6+6); the main diagonal is 84
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)
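
To see what relative_position_index actually contains, the same computation can be run standalone for a tiny 3x3 window (the 3x3 size is assumed only for readability; Swin uses 7x7 windows):

import torch

Wh, Ww = 3, 3  # tiny window for readability; Swin-T uses 7x7
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                    # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]    # 2, N, N
relative_coords = relative_coords.permute(1, 2, 0).contiguous()              # N, N, 2
relative_coords[:, :, 0] += Wh - 1          # shift row offsets to start from 0
relative_coords[:, :, 1] += Ww - 1          # shift column offsets to start from 0
relative_coords[:, :, 0] *= 2 * Ww - 1      # fold the 2-D offset into a 1-D index
relative_position_index = relative_coords.sum(-1)  # N, N

print(relative_position_index.min().item(), relative_position_index.max().item())
# 0 and (2*Wh-1)*(2*Ww-1) - 1 = 24; the main diagonal is (2*Ww-1)*(Wh-1) + (Ww-1) = 12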

The window-attention code also contains mask-related logic. This mask is generated to work together with shift_size; its purpose is to restrict attention to the correct sub-windows when regions with different offsets end up inside the same window. The paper describes this in its discussion of the shifted-window masking scheme.

        if mask is not None:
            nW = mask.shape[0]  # num_windows
            # mask: (num_windows, Wh*Ww, Wh*Ww) -> (1, num_windows, 1, Wh*Ww, Wh*Ww)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

[Figure: the shifted-window masking scheme as illustrated in the paper]
Next, let's look at how this mask is constructed. From the code you can see that the mask is derived by labeling regions of the img_mask template. Following the relationship given in the code, assume H = W = 14, window_size = 7, and shift_size = 3; the feature map is then divided into 2 x 2 = 4 windows.

if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    # equivalent to splitting img_mask into 9 regions
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            #num_windows*1, window_size, window_size, 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            #(num_windows,window_size*window_size)
            # ->(num_windows,1,window_size*window_size)-(num_windows,window_size*window_size,1)
            #->(num_windows,window_size*window_size,window_size*window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

The output of img_mask at this point is shown below: the nine (h, w) slice combinations label the 14 x 14 map with region ids 0 to 8.
[Figure: img_mask with its nine labeled regions]
After img_mask passes through window_partition, it becomes four 7 x 7 mask windows, each carrying the labels of the sub-windows it covers:
[Figure: the four mask windows after window_partition]
This matches the description in the paper. As for why such a mask template is needed, the paper's explanation of masking in the shifted configuration (shown below) makes it clear.
[Figure: excerpt from the paper on masking in the shifted-window configuration]
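
If you want to verify these values yourself, the following standalone sketch reproduces the mask for H = W = 14, window_size = 7, shift_size = 3 (window_partition is copied in from the Swin repo so the snippet runs on its own):

import torch

def window_partition(x, window_size):
    # same as in the Swin repo: (B, H, W, C) -> (num_windows*B, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

H = W = 14
window_size, shift_size = 7, 3

img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt  # nine regions labeled 0..8
        cnt += 1

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)

print(img_mask[0, :, :, 0])   # the 14x14 label map
print(attn_mask.shape)        # torch.Size([4, 49, 49]) -> one mask per window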

After partitioning, four 7 x 7 mask maps are produced, one per window attention. These four masks correspond to the windows obtained after the shifted input tensor of shape (B, H, W, C) is partitioned; within each window they restrict self-attention to positions that belong to the same sub-window of the shifted feature map, so that together the windows still realize self-attention over the entire input. The figure below shows a visualization of the final binary mask maps.
[Figure: visualization of the four attention mask maps]
The mask belongs to the SW-MSA path, and this is a key piece of code: the mask is added to the attention scores that have already been computed (q x k plus the relative position bias). For position pairs in the same sub-window the mask adds 0, so attention is computed normally; for pairs that should not attend to each other a large negative value (-100) is added, so after softmax their scores become essentially zero and they contribute almost nothing when multiplied by v. This achieves the goal of masking out positions that come from different sub-windows.

        if mask is not None:
            nW = mask.shape[0]  # num_windows
            # mask: (num_windows, Wh*Ww, Wh*Ww) -> (1, num_windows, 1, Wh*Ww, Wh*Ww)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
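
A one-line check of why -100 works as a mask: after softmax, a logit that is 100 lower than its neighbors gets essentially zero weight.

import torch

logits = torch.tensor([0.5, 0.2, 0.5 - 100.0])   # the last position is masked
print(torch.softmax(logits, dim=-1))             # ~[0.574, 0.426, 0.000]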

The above basically covers the window self-attention computation, and so far we have mainly discussed W-MSA. For SW-MSA, the paper also requires that the input feature map be cyclically shifted before window partitioning. It is worth noting that in the code shift_size is window_size // 2 at every stage; since window_size is a fixed value, shift_size is fixed as well. The relevant content of the paper and code is shown below.
[Figure: cyclic shift of the feature map for the shifted-window configuration, from the paper]
The following code implements the shift shown in the figure above.

# cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)

After SW-MSA is computed, the cyclic shift is reversed so that the feature map returns to its original layout:

# reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
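
A small demo of the shift-and-restore pair: rolling with a negative shift and then with the same positive shift recovers the original tensor exactly (shift_size = 3 is assumed here, i.e. window_size // 2 for window_size = 7):

import torch

shift_size = 3                                          # window_size // 2 with window_size = 7
x = torch.arange(14 * 14).view(1, 14, 14, 1).float()    # B, H, W, C

shifted = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))       # cyclic shift
restored = torch.roll(shifted, shifts=(shift_size, shift_size), dims=(1, 2))  # reverse shift

print(torch.equal(x, restored))  # True: the reverse roll undoes the shift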

That is what the SwinTransformerBlock module does. Let's now look at the pooling mechanism used in this paper: PatchMerging.

Pooling mechanism PatchMerging

This module down-samples the feature map, but differently from the straightforward pooling strategies used in convolutional networks. Here pixels are sampled at alternating positions along the horizontal and vertical directions, which splits the (H, W) feature map into four (H/2, W/2) maps; these are concatenated along the channel dimension and passed through a fully connected layer for dimensionality reduction (4C -> 2C). The result is that an input of shape (H, W, C) becomes (H/2, W/2, 2C), achieving a 2x down-sampling.

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
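
As a quick shape check, assuming the PatchMerging class above is defined in the same module: at the end of stage 1 of Swin-T the resolution drops from 56 x 56 to 28 x 28 while the channels double from 96 to 192.

import torch
# assumes PatchMerging (defined above) and its imports are available
merge = PatchMerging(input_resolution=(56, 56), dim=96)

x = torch.randn(1, 56 * 56, 96)    # (B, H*W, C) at the end of stage 1 of Swin-T
out = merge(x)
print(out.shape)                   # torch.Size([1, 784, 192]) = (B, 28*28, 2C)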

It is worth noting that the illustration in the paper groups the Patch Merging module together with the Swin blocks of stage 2, whereas in the code Patch Merging is the last step of stage 1 (it is the downsample layer of BasicLayer). This does not conflict with the illustration in the paper.
[Figure: stage boundaries in the paper's architecture diagram compared with the grouping in the code]

Summary

The Swin Transformer code is well worth studying carefully; doing so gives a much more complete understanding of the techniques in the paper. Writing this up was not easy, so if you have any questions, please leave a comment.

Original article: blog.csdn.net/qq_29750461/article/details/128872215