CSWin Transformer study notes

        CSWin proposes the cross-shaped window local self-attention shown in the figure above, which addresses the slow growth of the receptive field in local self-attention ViT-style models. It also proposes a locally-enhanced positional encoding module. The model surpassed Swin and was SOTA at the time on multiple tasks (that SOTA has since been surpassed by SGFormer and other models, which interested readers can look into).

Paper address: https://arxiv.org/abs/2107.00652

Code address: https://github.com/microsoft/CSWin-Transformer

        The overall structure of the model is shown above: a convolutional token embedding layer followed by 4 stacked stages, each stage block with its own channel dimension (dim). Similar to the typical R50 design, a convolution layer is used between stages to downsample the feature map.

Research motivation:

  • Transformers based on global attention work well, but their computational complexity grows quadratically with the feature map size (in the H == W case).
  • Transformers based on local attention restrict the interaction between tokens to a small receptive field; the receptive field grows slowly, so a large number of blocks must be stacked to achieve global self-attention.

Solution:

  • Proposes the Cross-Shaped Window self-attention mechanism, which splits the attention heads into groups that compute horizontal-stripe and vertical-stripe self-attention in parallel, achieving better results with less computation (a minimal sketch of the head-grouping idea follows this list).
  • Proposes Locally-enhanced Positional Encoding (LePE), which handles local positional information better and supports inputs of arbitrary shape.
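
A rough illustration of the head-grouping idea (my own sketch, not the paper's code): the channel dimension is split in half, one half goes through horizontal-stripe attention, the other half through vertical-stripe attention, and the two outputs are concatenated. The callables h_attn and v_attn below are placeholders for the stripe-attention branches implemented later in this post.

import torch

def cross_shaped_attention(x, h_attn, v_attn):
    # x: (B, H*W, C); h_attn / v_attn are stand-ins for horizontal / vertical stripe attention
    C = x.shape[-1]
    out_h = h_attn(x[..., :C // 2])   # first half of the channels (half of the heads)
    out_v = v_attn(x[..., C // 2:])   # second half of the channels
    return torch.cat([out_h, out_v], dim=-1)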

1.1 Convolutional Token Embedding

        Convolution is used for the embedding. To reduce computation, the paper directly embeds the input with a 7x7 convolution kernel and stride 4, then applies LayerNorm over the last (channel) dimension.

self.stage1_conv_embed = nn.Sequential(
    nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
    Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
    nn.LayerNorm(embed_dim)
)
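
A quick shape check (my own snippet; the embed_dim value and the einops import are illustrative, following the CSWin repo's style):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

img_size, in_chans, embed_dim = 224, 3, 64   # illustrative values
embed = nn.Sequential(
    nn.Conv2d(in_chans, embed_dim, 7, 4, 2),
    Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
    nn.LayerNorm(embed_dim),
)
tokens = embed(torch.randn(2, in_chans, img_size, img_size))
print(tokens.shape)   # torch.Size([2, 3136, 64]); 3136 = 56 * 56 tokens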

1.2 Cross-Shaped Window Self-Attention

        Specifically, assume the input feature map is H\times W\times C. To compute self-attention in the horizontal direction, it is first split into M = \frac{H}{sw} horizontal stripes (the actual code happens to process the vertical direction first), where sw is the stripe width. For example, with H = 56 and sw = 1 there are 56 stripes of size 1\times 56. sw takes a different value in each of the four stages; the experiments show that the set of values [1, 2, 7, 7] achieves a good balance between speed and accuracy.

        For each stripe feature X^{i}, i = 1, 2, ..., M, self-attention is computed to obtain its feature Y^{i}, and the M stripe outputs are concatenated back together to form this head's output. For the k-th head, horizontal self-attention H\text{-}Attention_{k}(X) is computed as:

X = [X^{1}, X^{2}, ..., X^{M}], X^{i} \in \mathbb{R}^{(sw\times W)\times C}

Y_{k}^{i} = Attention(X^{i}W_{k}^{Q}, X^{i}W_{k}^{K}, X^{i}W_{k}^{V}), i = 1, ..., M

H\text{-}Attention_{k}(X) = [Y_{k}^{1}, Y_{k}^{2}, ..., Y_{k}^{M}]

The vertical self-attention V-Attention is computed in the same way as H-Attention, except that it uses vertical stripes of width sw.

Finally, the output of this block is expressed as:

head_{k} = H\text{-}Attention_{k}(X) for k = 1, ..., K/2, and head_{k} = V\text{-}Attention_{k}(X) for k = K/2+1, ..., K

CSWin\text{-}Attention(X) = Concat(head_{1}, ..., head_{K})W^{O}

CSWin self-attention computational complexity analysis:

For high-resolution inputs, H and W are larger than C in the early stages and smaller than C in the later stages, so sw is kept small early and large late; adjusting sw this way effectively enlarges the attention area of each token in the later stages. To keep the intermediate feature map sizes of a 224×224 input divisible by sw, the sw of the 4 stages is set to 1, 2, 7, 7 by default.
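
As a rough back-of-the-envelope estimate (my own derivation, not quoted from the paper or the original post): the Q/K/V and output projections cost about 4HWC^{2}; a horizontal stripe contains sw\times W tokens, and with half the channels (C/2) its attention costs roughly (sw\,W)^{2}C, with \frac{H}{sw} such stripes giving about sw\cdot W\cdot HWC; the vertical branch similarly adds sw\cdot H\cdot HWC. The total is therefore approximately 4HWC^{2} + sw(H+W)\cdot HWC, so for a fixed sw the attention term grows only linearly in H and W instead of quadratically in HW as in global attention.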

def img2windows(img, H_sp, W_sp):
    """
    img: B C H W
    """
    B, C, H, W = img.shape
    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
    img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) # [N*56*1 56 32] [N*56*1 56 32] / [N*14*1 56 64] [N*14*1 56 64] / [N*2*1 98 128] [N*2*1 98 128] / [N*1*1 49 512]
    return img_perm

def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """
    img_splits_hw: B' H W C
    """
    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))

    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) # [N*56*1 56 32]->[N 1 56 56 1 32] [N*56*1 56 32]->[N 56 1 1 56 32] / [N*14*1 56 64]->[N 1 14 28 2 64] [N*14*1 56 64]->[N 14 1 2 28 64] / [N*2*1 98 128]->[N 1 2 14 7 128] [N*2*1 98 128]->[N 2 1 7 14 128] / [N*1*1 49 512]->[N 1 1 7 7 512]
    img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # [N 56 56 32] [N 28 28 64] [N 14 14 128] [N 7 7 512]
    return img
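
A quick round-trip sanity check (my own snippet, not from the original post) showing that windows2img exactly inverts img2windows, up to the B C H W -> B H W C layout change:

import torch

x = torch.randn(2, 32, 56, 56)                        # B C H W
win = img2windows(x, H_sp=56, W_sp=1)                  # (2*56, 56, 32): 56 stripes per image
back = windows2img(win, H_sp=56, W_sp=1, H=56, W=56)   # (2, 56, 56, 32): B H W C
print(torch.allclose(back, x.permute(0, 2, 3, 1)))     # True: the windowing is lossless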

import numpy as np
import torch.nn as nn

class LePEAttention(nn.Module):
    def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0.,
                 qk_scale=None):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.resolution = resolution
        self.split_size = split_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if idx == -1:
            H_sp, W_sp = self.resolution, self.resolution
        elif idx == 0:
            H_sp, W_sp = self.resolution, self.split_size
        elif idx == 1:
            W_sp, H_sp = self.resolution, self.split_size
        else:
            print("ERROR MODE", idx)
            exit(0)
        self.H_sp = H_sp
        self.W_sp = W_sp
        stride = 1
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

        self.attn_drop = nn.Dropout(attn_drop)

    def im2cswin(self, x):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [B, N, C] -> [B, C, N] -> [B, C, H, W]
        x = img2windows(x, self.H_sp, self.W_sp)  # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]
        x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1,
                                                                                              3).contiguous()  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        return x

    def get_lepe(self, x, func):
        B, N, C = x.shape  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]

        H_sp, W_sp = self.H_sp, self.W_sp
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
                   W_sp)  # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
                                                             W_sp)  ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]

        lepe = func(
            x)  ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14]  / [N*1*1 512 7 7]
        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
                                                                                          2).contiguous()  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]

        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
                                                                                              2).contiguous()  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        return x, lepe

    def forward(self, qkv):
        """
        x: B L C
        """
        q, k, v = qkv[0], qkv[1], qkv[2]  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]

        ### Img2Window
        H = W = self.resolution  # 56 28 14 7
        B, L, C = q.shape  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
        assert L == H * W, "flatten img_tokens has wrong size"

        q = self.im2cswin(q)  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        k = self.im2cswin(k)  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        v, lepe = self.get_lepe(v, self.get_v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B head N C @ B head C N --> B head N N
        attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp,
                                      C)  # B head N N @ B head N C # [N*56*1 56 32] [N*14*1 56 64] [N*2*1 98 128] [N*1*1 49 512]

        ### Window2Img
        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)  # B H' W' C

        return x  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
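
A hypothetical smoke test (my own snippet; the dim/resolution values are illustrative, matching a stage-1-like setting) confirming that the module maps (B, H*W, C) queries/keys/values back to the same token layout:

import torch

attn = LePEAttention(dim=32, resolution=56, idx=0, split_size=1, num_heads=1)
qkv = torch.randn(3, 2, 56 * 56, 32)    # (q/k/v, batch, tokens, channels)
out = attn(qkv)
print(out.shape)                        # torch.Size([2, 3136, 32])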

The code is actually similar to Swin: if you understand Swin's windowing mechanism and add head grouping on top of it, you can quickly grasp the ideas in the paper.

1.3 Locally-Enhanced Positional Encoding (LePE)

        Because the Transformer is invariant to the order of its input, positional encoding must be added. The left of the figure above is the PE of the ViT model: absolute positional encoding (APE) or conditional positional encoding (CPE), which is only added to the tokens at embedding time before they enter the transformer. The middle is the PE of Swin, CrossFormer and other models: a relative positional encoding bias (RPE) that introduces position-dependent weights into the attention computation itself, which is more flexible and works better than APE.

        The LePE proposed in this paper is more direct than RPE: it applies the positional information to the linearly projected values. Note also that RPE introduces the bias per head, while LePE is a per-channel bias, which may have stronger potential as a positional embedding. In other words, a positional encoding operator E acts directly on the Value vectors, and its output is added through a shortcut to the self-attention-weighted values:

Attention(Q, K, V) = SoftMax(QK^{T}/\sqrt{d})V + EV

        The author here relies on an assumption: for a given input element, its nearby elements provide the most important positional information. So E is implemented as a depth-wise convolution applied to V, and the result is added to the softmax-weighted output:

Attention(Q, K, V) = SoftMax(QK^{T}/\sqrt{d})V + DWConv(V)

        In this way, LePE can be friendly to downstream tasks that take arbitrary input resolutions as input.
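
A minimal single-head illustration of this formula (my own sketch, separate from the repo code below): the positional term is just a 3x3 depth-wise convolution on V, added after the attention-weighted aggregation.

import torch
import torch.nn as nn

B, C, H, W = 2, 32, 7, 7
v = torch.randn(B, C, H, W)
dwconv = nn.Conv2d(C, C, kernel_size=3, stride=1, padding=1, groups=C)   # the LePE operator
q = k = torch.randn(B, H * W, C)
attn = torch.softmax(q @ k.transpose(-2, -1) / C ** 0.5, dim=-1)         # (B, HW, HW)
out = attn @ v.flatten(2).transpose(1, 2) + dwconv(v).flatten(2).transpose(1, 2)
print(out.shape)   # torch.Size([2, 49, 32])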

    def get_lepe(self, x, func):
        # func -> self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim)
        B, N, C = x.shape  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [N 32 56 56] [N 64 28 28] [N 128 14 14] [N 512 7 7]

        H_sp, W_sp = self.H_sp, self.W_sp
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp,
                   W_sp)  # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp,
                                                             W_sp)  ### B', C, H', W' # [N*56*1 32 56 1][N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14] / [N*1*1 512 7 7]

        lepe = func(
            x)  ### B', C, H', W' # [N*56*1 32 56 1] [N*56*1 32 1 56] / [N*14*1 64 28 2][N*14*1 64 2 28] / [N*2*1 128 14 7][N*2*1 128 7 14]  / [N*1*1 512 7 7]
        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3,
                                                                                          2).contiguous()  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]

        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3,
                                                                                              2).contiguous()  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        return x, lepe

1.4 CSWin Transformer Block

        The structure of CSWin Transformer Block is shown in the figure. Its most significant feature is that it adds two shortcuts and uses LN to normalize the features.

Network structure configuration:

        The block is computed as \hat{X}^{l} = CSWin\text{-}Attention(LN(X^{l-1})) + X^{l-1} and X^{l} = MLP(LN(\hat{X}^{l})) + \hat{X}^{l}, where X^{l} is the output of the l-th Transformer block or of the convolutional layer of each stage.

        CSWin's block has two parts: one applies LayerNorm and cross-shaped window self-attention with a shortcut connection, the other applies LayerNorm and an MLP with a shortcut. Compared with Swin and Twins, the computation per block is greatly reduced (Swin and Twins stack two attentions + two MLPs into one block unit).

import torch
import torch.nn as nn
# Assumes DropPath (e.g. from timm) and the repo's Mlp are in scope, as in the CSWin codebase.

class CSWinBlock(nn.Module):

    def __init__(self, dim, reso, num_heads,
                 split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 last_stage=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.patches_resolution = reso
        self.split_size = split_size
        self.mlp_ratio = mlp_ratio
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm1 = norm_layer(dim)

        if self.patches_resolution == split_size:
            last_stage = True
        if last_stage:
            self.branch_num = 1
        else:
            self.branch_num = 2
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(drop)
        
        if last_stage:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim, resolution=self.patches_resolution, idx = -1,
                    split_size=split_size, num_heads=num_heads, dim_out=dim,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])
        else:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim//2, resolution=self.patches_resolution, idx = i,
                    split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])
        

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """

        H = W = self.patches_resolution # 56
        B, L, C = x.shape # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
        assert L == H * W, "flatten img_tokens has wrong size"
        img = self.norm1(x)
        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # [3 N 3136 64] [3 N 784 128] [3 N 196 256] [3 N 49 512]
        
        if self.branch_num == 2:
            x1 = self.attns[0](qkv[:,:,:,:C//2]) # qkv[3 N 3136 32]->x1[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
            x2 = self.attns[1](qkv[:,:,:,C//2:]) # qkv[3 N 3136 32]->x2[N 3136 32] qkv[3 N 784 128]->x1[N 784 64] qkv[3 N 196 256]->x1[N 196 128]
            attened_x = torch.cat([x1,x2], dim=2)
        else:
            attened_x = self.attns[0](qkv) # [3 N 49 512]->[N 49 512]
        attened_x = self.proj(attened_x)
        x = x + self.drop_path(attened_x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
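
A hypothetical smoke test (my own sketch): the excerpt above depends on Mlp from the CSWin repo, so a simplified stand-in is defined here just to run the block. With the default drop_path=0., DropPath is never instantiated, so no stand-in is needed for it.

import torch
import torch.nn as nn

class Mlp(nn.Module):
    # simplified stand-in for the repo's Mlp (fc -> act -> fc, each followed by dropout)
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features or in_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features or in_features, out_features or in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))

blk = CSWinBlock(dim=64, reso=56, num_heads=2, split_size=1)   # stage-1-like settings
x = torch.randn(2, 56 * 56, 64)
print(blk(x).shape)   # torch.Size([2, 3136, 64])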

Among models with similar parameter counts and computation, CSWin achieved SOTA on classification and various downstream tasks.

 Detection:

Segmentation:


Origin blog.csdn.net/athrunsunny/article/details/133772022