CSWin Transformer Learning Notes

        CSWin proposes the cross-shaped window local attention shown in the figure above, addressing the limited growth of the receptive field under local self-attention in ViT-style models. It also proposes a locally-enhanced positional encoding module. It surpassed Swin and related models and achieved SOTA results on multiple tasks (SOTA at the time; it has since been surpassed by SG-Former, which is worth a look if you are interested).

Paper: https://arxiv.org/abs/2107.00652

Code: https://github.com/microsoft/CSWin-Transformer

        The overall model structure, shown above, stacks a token embedding layer and 4 stages of blocks; each stage is followed by a conv layer that downsamples the feature map. Similar to the typical ResNet-50 design, the channel dimension is increased after each downsampling, both to enlarge the receptive field and to enrich the feature representation.

Motivation:

  • Transformers based on global attention perform well, but their computational complexity grows quadratically with the feature-map size (for H == W).
  • Transformers based on local attention restrict the interaction range of each token, slowing the growth of the receptive field, so many blocks must be stacked to reach global self-attention.

Solutions:

  • Propose the Cross-Shaped Window self-attention mechanism: the attention heads are split into groups that compute horizontal and vertical self-attention in parallel, yielding better results at a lower computational cost.
  • Propose the Locally-enhanced Positional Encoding (LePE), which handles local positional information better and supports inputs of arbitrary shape.

1.1 Convolutional Token Embedding

        A convolution is used for the embedding. To reduce computation, the paper directly applies a 7×7 convolution with stride 4 to embed the input, followed by a LayerNorm over the last dimension.

# Rearrange comes from einops.layers.torch
self.stage1_conv_embed = nn.Sequential(
    nn.Conv2d(in_chans, embed_dim, 7, 4, 2),  # 7x7 kernel, stride 4, padding 2
    Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
    nn.LayerNorm(embed_dim)
)
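As a quick sanity check, here is a sketch of the resulting token shape (img_size=224, in_chans=3, embed_dim=64 are assumed CSWin-T settings):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

img_size, in_chans, embed_dim = 224, 3, 64  # assumed CSWin-T settings
embed = nn.Sequential(
    nn.Conv2d(in_chans, embed_dim, 7, 4, 2),  # 7x7 kernel, stride 4, padding 2
    Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),
    nn.LayerNorm(embed_dim),
)
x = torch.randn(1, in_chans, img_size, img_size)
print(embed(x).shape)  # torch.Size([1, 3136, 64]), i.e. 56*56 tokens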

1.2 Cross-Shaped Window Self-Attention

        Specifically, suppose the original feature map is $H \times W \times C$. To compute its horizontal self-attention, it is first split into $M = \frac{H}{sw}$ horizontal stripes (in the code, the vertical split is actually handled first), where $sw$ is the stripe width. $sw$ takes a different value in each of the 4 stages; experiments show that the values $[1, 2, 7, 7]$ strike a good balance between speed and accuracy.

        For each stripe feature $X^i$, $i = 1, 2, \dots, M$, self-attention produces its feature $Y^i$; concatenating these $M$ features yields the output of this head. Assuming it belongs to the $k$-th head, the horizontal self-attention $\text{H-Attention}_k(X)$ is computed as:
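Reconstructed from the paper, with $X$ split into stripes $[X^1, X^2, \dots, X^M]$, $X^i \in \mathbb{R}^{(sw \times W) \times C}$, and $W_k^Q, W_k^K, W_k^V \in \mathbb{R}^{C \times d_k}$ the projection matrices of head $k$ ($d_k = C/K$ for $K$ heads):

$$Y_k^i = \text{Attention}\big(X^i W_k^Q,\ X^i W_k^K,\ X^i W_k^V\big), \quad i = 1, \dots, M$$

$$\text{H-Attention}_k(X) = \big[Y_k^1;\ Y_k^2;\ \dots;\ Y_k^M\big]$$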

The vertical self-attention V-Attention is computed analogously to H-Attention, except that it takes vertical stripes of width $sw$.

Finally, the output of this block is:
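Per the paper, the $K$ heads are split evenly between the two directions and projected back:

$$\text{CSWin-Attention}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_K)\, W^O, \quad \text{head}_k = \begin{cases} \text{H-Attention}_k(X), & k = 1, \dots, K/2 \\ \text{V-Attention}_k(X), & k = K/2 + 1, \dots, K \end{cases}$$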

Computational complexity analysis of CSWin self-attention:

For high-resolution inputs, H and W are larger than C in the early stages and smaller than C in the later stages, so a small sw is used early and a large one later; that is, adjusting sw effectively enlarges the attention area of each token in the later stages. To keep the intermediate feature maps of a 224×224 input divisible by sw, the sw of the 4 stages defaults to 1, 2, 7, 7.
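As a back-of-envelope sketch (my own derivation, counting only the projection and attention matmuls in the style of Swin's analysis): the horizontal half of the heads attends within $H/sw$ stripes of $sw \cdot W$ tokens each, and the vertical half within $W/sw$ stripes of $sw \cdot H$ tokens, so

$$\Omega(\text{CSWin}) \approx 4HWC^2 + sw \cdot HW(H + W)\,C,$$

versus $\Omega(\text{MSA}) = 4HWC^2 + 2(HW)^2 C$ for full attention; the attention term is linear rather than quadratic in the number of tokens $HW$.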

import numpy as np
import torch
import torch.nn as nn


def img2windows(img, H_sp, W_sp):
    """
    img: B C H W
    """
    B, C, H, W = img.shape
    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) # [N 32 1 56 56 1] [N 32 56 1 1 56] / [N 64 1 28 14 2] [N 64 14 2 1 28] / [N 128 1 14 2 7] [N 128 2 7 1 14] / [N 512 1 7 1 7]
    img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp* W_sp, C) # [N*56*1 56 32] [N*56*1 56 32] / [N*14*1 56 64] [N*14*1 56 64] / [N*2*1 98 128] [N*2*1 98 128] / [N*1*1 49 512]
    return img_perm

def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """
    img_splits_hw: B' H W C
    """
    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))

    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) # [N*56*1 56 32]->[N 1 56 56 1 32] [N*56*1 56 32]->[N 56 1 1 56 32] / [N*14*1 56 64]->[N 1 14 28 2 64] [N*14*1 56 64]->[N 14 1 2 28 64] / [N*2*1 98 128]->[N 1 2 14 7 128] [N*2*1 98 128]->[N 2 1 7 14 128] / [N*1*1 49 512]->[N 1 1 7 7 512]
    img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # [N 56 56 32] [N 28 28 64] [N 14 14 128] [N 7 7 512]
    return img
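A quick round-trip check of the two helpers (a sketch; the shapes match stage 1 with H_sp=56, W_sp=1):

x = torch.randn(2, 32, 56, 56)              # B C H W
win = img2windows(x, H_sp=56, W_sp=1)       # -> [2*1*56, 56, 32]
back = windows2img(win, 56, 1, 56, 56)      # -> [2, 56, 56, 32], i.e. B H W C
assert back.permute(0, 3, 1, 2).equal(x)    # the two functions are exact inverses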

class LePEAttention(nn.Module):
    def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, attn_drop=0., proj_drop=0.,
                 qk_scale=None):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.resolution = resolution
        self.split_size = split_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if idx == -1:
            H_sp, W_sp = self.resolution, self.resolution
        elif idx == 0:
            H_sp, W_sp = self.resolution, self.split_size
        elif idx == 1:
            W_sp, H_sp = self.resolution, self.split_size
        else:
            raise ValueError(f"unknown idx {idx}")
        self.H_sp = H_sp
        self.W_sp = W_sp
        # depthwise conv that produces the locally-enhanced positional encoding from V
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

        self.attn_drop = nn.Dropout(attn_drop)

    def im2cswin(self, x):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [B, N, C] -> [B, C, H, W]
        x = img2windows(x, self.H_sp, self.W_sp)  # -> [B*num_windows, H_sp*W_sp, C]
        # split channels into heads: -> [B*num_windows, num_heads, H_sp*W_sp, C//num_heads]
        x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        return x

    def get_lepe(self, x, func):
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [B, N, C] -> [B, C, H, W]

        H_sp, W_sp = self.H_sp, self.W_sp
        # partition into stripes: [B, C, H//H_sp, H_sp, W//W_sp, W_sp] -> [B', C, H_sp, W_sp]
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)

        # the depthwise conv on V is the locally-enhanced positional encoding
        lepe = func(x)  # [B', C, H_sp, W_sp]
        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous()

        # reshape V to the attention layout: [B', num_heads, H_sp*W_sp, C//num_heads]
        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3, 2).contiguous()
        return x, lepe

    def forward(self, qkv):
        """
        qkv: [3, B, H*W, C]
        """
        q, k, v = qkv[0], qkv[1], qkv[2]  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]

        ### Img2Window
        H = W = self.resolution  # 56 28 14 7
        B, L, C = q.shape  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]
        assert L == H * W, "flatten img_tokens has wrong size"

        q = self.im2cswin(q)  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        k = self.im2cswin(k)  # [N*56*1 1 56 32] [N*14*1 2 56 32] [N*2*1 4 98 32] [N*1*1 16 49 32]
        v, lepe = self.get_lepe(v, self.get_v)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B head N C @ B head C N --> B head N N
        attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C)  # merge heads back: [B', H_sp*W_sp, C]

        ### Window2Img
        x = windows2img(x, self.H_sp, self.W_sp, H, W).view(B, -1, C)  # B H' W' C

        return x  # [N 3136 32] [N 784 64] [N 196 128] [N 49 512]

The code is quite similar to Swin's: once you understand Swin's window-partitioning mechanism, plus the head grouping here, the idea of the paper follows quickly.
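A minimal smoke test of the module (hypothetical stage-1-like settings; assumes the imports above are in scope):

attn = LePEAttention(dim=32, resolution=56, idx=0, split_size=1, num_heads=1)
qkv = torch.randn(3, 2, 56 * 56, 32)  # [3, B, L, C]
out = attn(qkv)
print(out.shape)                      # torch.Size([2, 3136, 32])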

1.3 Locally-Enhanced Positional Encoding (LePE)

        Since the Transformer is permutation-invariant with respect to its input, positional encodings must be added. In the figure above, the left shows the PE of models like ViT: absolute or conditional positional encodings that enter the Transformer together with the tokens only at embedding time. The middle shows the PE of models like Swin and CrossFormer: a relative position bias that participates in the attention computation via token-pair weights, which is more flexible and works better than APE.

        The LePE proposed here is more direct than RPE, imposing the positional information within the linear projection. Note also that RPE introduces the bias per head, whereas LePE is a per-channel bias, which may show stronger potential as a position embedding. Concretely, the positional encoding is applied directly to the Value vector: denoting the encoding as $E$, it is applied by multiplying $E$ with $V$, and a shortcut then adds the position-encoded $V$ to the attention-weighted $V$:
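In the paper's notation:

$$\text{Attention}(Q, K, V) = \text{SoftMax}\big(QK^{\top}/\sqrt{d}\big)\,V + E\,V$$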

        Here the authors rely on an assumption: for a given input element, its nearby elements provide the most important positional information. So a depthwise convolution is applied to $V$ and added to the result after the softmax:
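$$\text{Attention}(Q, K, V) = \text{SoftMax}\big(QK^{\top}/\sqrt{d}\big)\,V + \text{DWConv}(V)$$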

        In this way, LePE is friendly to downstream tasks that take arbitrary input resolutions.

    def get_lepe(self, x, func):
        # func -> self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        x = x.transpose(-2, -1).contiguous().view(B, C, H, W)  # [B, N, C] -> [B, C, H, W]

        H_sp, W_sp = self.H_sp, self.W_sp
        # partition into stripes: [B, C, H//H_sp, H_sp, W//W_sp, W_sp] -> [B', C, H_sp, W_sp]
        x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp)

        # the depthwise conv on V is the locally-enhanced positional encoding
        lepe = func(x)  # [B', C, H_sp, W_sp]
        lepe = lepe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous()

        # reshape V to the attention layout: [B', num_heads, H_sp*W_sp, C//num_heads]
        x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3, 2).contiguous()
        return x, lepe

1.4 CSWin Transformer Block

        The structure of the CSWin Transformer Block is shown in the figure: its most distinctive features are the two shortcuts and the use of LayerNorm to normalize the features.
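The two residual sub-layers follow the paper's formulation:

$$\hat{X}^l = \text{CSWin-Attention}\big(\text{LN}(X^{l-1})\big) + X^{l-1}$$

$$X^l = \text{MLP}\big(\text{LN}(\hat{X}^l)\big) + \hat{X}^l$$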

        where $X^l$ denotes the output of the $l$-th Transformer block or, at a stage boundary, of that stage's convolutional layer.

Network structure configurations:

        A CSWin block has two parts: one applies LayerNorm and cross-shaped window self-attention followed by a shortcut; the other applies LayerNorm and an MLP. Compared with Swin and Twins, the per-block computation drops considerably (Swin and Twins stack two attentions plus two MLPs to form one block unit).

# In the original repo, DropPath comes from timm.models.layers and Mlp is the timm-style MLP defined alongside
class CSWinBlock(nn.Module):

    def __init__(self, dim, reso, num_heads,
                 split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 last_stage=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.patches_resolution = reso
        self.split_size = split_size
        self.mlp_ratio = mlp_ratio
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm1 = norm_layer(dim)

        if self.patches_resolution == split_size:
            last_stage = True
        if last_stage:
            self.branch_num = 1
        else:
            self.branch_num = 2
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(drop)
        
        if last_stage:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim, resolution=self.patches_resolution, idx = -1,
                    split_size=split_size, num_heads=num_heads, dim_out=dim,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])
        else:
            self.attns = nn.ModuleList([
                LePEAttention(
                    dim//2, resolution=self.patches_resolution, idx = i,
                    split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
                    qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
                for i in range(self.branch_num)])
        

        mlp_hidden_dim = int(dim * mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """

        H = W = self.patches_resolution # 56
        B, L, C = x.shape # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
        assert L == H * W, "flatten img_tokens has wrong size"
        img = self.norm1(x)
        qkv = self.qkv(img).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # [3 N 3136 64] [3 N 784 128] [3 N 196 256] [3 N 49 512]
        
        if self.branch_num == 2:
            x1 = self.attns[0](qkv[:, :, :, :C // 2])  # horizontal half: [3 N L C/2] -> [N L C/2]
            x2 = self.attns[1](qkv[:, :, :, C // 2:])  # vertical half:   [3 N L C/2] -> [N L C/2]
            attened_x = torch.cat([x1,x2], dim=2)
        else:
            attened_x = self.attns[0](qkv) # [3 N 49 512]->[N 49 512]
        attened_x = self.proj(attened_x)
        x = x + self.drop_path(attened_x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x # [N 3136 64] [N 784 128] [N 196 256] [N 49 512]
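A minimal smoke test of the block (hypothetical stage-1-like settings; assumes Mlp and DropPath, e.g. from timm, plus the imports above are in scope):

blk = CSWinBlock(dim=64, reso=56, num_heads=2, split_size=1)
x = torch.randn(2, 56 * 56, 64)
print(blk(x).shape)  # torch.Size([2, 3136, 64])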

Among models with similar parameter counts and FLOPs, CSWin achieves SOTA on classification and a variety of downstream tasks.

Detection:

Segmentation:


Reprinted from blog.csdn.net/athrunsunny/article/details/133772022