基础论文学习(3)——SwinTransformer

目前Transformer应用到图像领域的挑战:

  • 图像分辨率高,像素点多,如果需要更多特征就必须构建很长的序列,但Transformer基于全局自注意力的计算导致计算量较大,能否用窗口+分层的形式代替长序列,实现类似CNN感受野的效果?

针对上述问题,我们提出了一种包含滑窗操作,具有层级设计的Swin Transformer,逐层合并tokens。
在这里插入图片描述

其中滑窗操作包括不重叠的local window + 重叠的cross-window将注意力计算限制在一个窗口中,一方面能引入CNN卷积操作的局部性,另一方面能节省计算量

1. SwinTransformer总体架构

整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 在输入开始的时候,做了一个Patch Embedding,将图片切成一个个图块(对image进行卷积,然后对特征图切分为patch),并嵌入到Embedding,构建token序列。
  • 在每个Stage里,由Patch Merging和多个Block组成。
  • 其中Patch Merging模块主要在每个Stage一开始进行下采样(W和H不断减小,C不断增大),降低图片分辨率。
  • 而Block具体结构如右图所示,主要是LayerNormMLPWindow AttentionShifted Window Attention组成 (提供了2种attention计算方法)
    在这里插入图片描述
class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            
        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
# step1:Patch Embedding
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        
# step2:BasicLayer = feature_shift + Window Partition + W-MSA/SW-MSA + Window Reverse + reverse_shift + Patch Merging
        for layer in self.layers:  # 遍历4个stage
            x = layer(x)  

# step3:LN + AvgPool + flatten
        x = self.norm(x)  # (B, L, C)->(B, L, C)
        x = self.avgpool(x.transpose(1, 2))  # (B, L, C)->(B, C, 1)
        x = torch.flatten(x, 1)  # (B, C, 1)->(B, C)
        
        return x

    def forward(self, x):
        x = self.forward_features(x)
# step4:FC(不同任务的Head层不同)
        x = self.head(x)  # # (B, C)->(B, num_class)
        return x

其中有几个地方处理方法与ViT不同:

  • ViT在输入会给embedding进行位置编码。而Swin-T这里则是作为一个可选项(self.ape),Swin-T是在计算Attention的时候做了一个相对位置编码
  • ViT会单独加上一个可学习参数,作为分类的token。而Swin-T则是直接做平均,输出分类,有点类似CNN最后的全局平均池化层

1.1 Patch Embedding

在输入进Block前,我们需要将图片切成一个个patch,然后嵌入向量。

具体做法是对原始图片(224,224,3)裁成一个个 patch_size * patch_size的窗口大小,然后进行嵌入。

这里可以将stride=4,kernel_size=4设置为patch_size=4大小,按照VIT中patch embedding的方式(不重叠卷积)得到每一个图像块patch对应长度为embed_dim的向量。设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。输出(3136, 96)相当于3136个长度为96的token,j将tokens序列排列为正方形即(56*56, 96)

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size) # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size) # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # 这里!!
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        # 假设采取默认参数
        x = self.proj(x) # 出来的是(N, 96, 224/4, 224/4) 
        x = torch.flatten(x, 2) # 把HW维展开,(N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # 把通道维放到最后 (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x

1.2 Window Partition/Reverse

window partition函数是用于对张量按非重叠窗口大学window_size划分为一条条tokens,指定窗口大小。将原本的张量从 N H W C, 划分成 num_windows*B, window_size, window_size, C,其中 num_windows = H*W / (window_size*window_size),即窗口的个数。

如输入特征图(56,56,96),默认window_size=7x7,所以分为8x8个窗口,num_windows=64,输出特征图(64, 7, 7, 96),之前的单位是token(共56x56=3136个token),现在的单位是窗口(共8x8=64个window,每个window聚集了7x7=49个token),最后把每个window内的token聚合展平为一个大token,每个大token的shape=(49,96)

window reverse函数则是对应的逆过程。这两个函数会在后面的Window Attention用到。
在这里插入图片描述
实现起来,window partition和window reverse没有可学习参数,因而不需要继承其他的类,写成函数就行。上面windows_partition是将送进来的特征进行window_size的划分,最终变为一条条tokens(对应示意图!!!)

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

1.3 W-MSA 和 SW-MSA

两者串联起来就是一个Swin Transformer Block:

  • W-MSA 窗口多头自注意力机制(windows multi-head self attention):窗口内部multi-head self-attention
  • SW-MSA 滑动窗口多头自注意力机制(shift windows multi-head self attention):窗口之间multi-head self-attention
    在这里插入图片描述

W-MSA

传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量。

输入特征图(64, 7, 7, 96),window size=7(包含7x7个长度96的token),共64个窗口。
在这里插入图片描述

swin transformer是按照window size内的小方格计算self-attention的,比如上图中的windows size=7,也就是每7*7个tokens(红色框)之间计算多头self-attention(head=3)。

一次性用x计算出qkv三个矩阵:3个qkv矩阵放在一起的shape=(3, 64, 3, 49, 32),3个矩阵,64个window,head=3, 窗口大小=7x7=49,每个head特征长度96/3=32,64个窗口自己的attention结果是(64, 3, 49, 49)。

这里注意,计算self-attention的输入tokens的数量和维度都不变换,因此最终的输出特征图依旧是(64, 49, 96),64个窗口,每个窗口7x7个token,每个96维的token都会学习到了窗口内的自注意力。

SW-MSA

前面的Window Attention是在每个窗口下计算注意力的,为了更好的和其他window进行信息交互,Swin Transformer还引入了shifted window操作。

左边是没有重叠的Window Attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本4个窗口变成了9个窗口。

在这里插入图片描述
在实际代码里,我们是通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价
在这里插入图片描述

特征图移位+Mask操作

对特征图位移(torch.roll)之后,还是按照4个窗口计算attention,但是会有冗余计算结果,直接设置对应位置mask为负无穷(softmax后为0),忽略不需要的attetion部分(图中灰色部分),输出的结果同W-MSA 也是(56, 56, 96,不要忘记计算完对特征图还原平移)。
在这里插入图片描述

我们看下Block的前向代码:

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)

    # FFN
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x

整体流程如下

  • 先对特征图进行LayerNorm
  • 通过self.shift_size决定是否需要对特征图进行shift
  • 然后将特征图切成一个个窗口
  • 计算Attention,通过self.attn_mask来区分Window Attention还是Shift Window Attention
  • 将各个窗口合并回来
  • 如果之前有做shift操作,此时进行reverse shift,把之前的shift操作恢复
  • 做dropout和残差连接
  • 再通过一层LayerNorm+全连接层,以及dropout和残差连接

1.4 Patch Merging

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数 进而形成层次化的设计,同时也能节省一定运算量。
在这里插入图片描述

在CNN中,则是在每个Stage开始前用stride=2的卷积/池化层来降低分辨率。

每次降采样是两倍,因此在行方向和列方向上,间隔2选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接层再调整通道维度为原来的两倍。如输入(56, 56, c),变为(28, 28, 4c),全连接输出(28, 28, 2c),这样就使得下一个stage的窗口数量减少了。

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({
      
      H}*{
      
      W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

下面是一个示意图(输入张量B=1, H=W=8, C=1,不包含最后的全连接层调整)
在这里插入图片描述

2. 实验分析

在这里插入图片描述
在ImageNet22K数据集上,准确率能达到惊人的86.4%。另外在检测,分割等任务上表现也很优异。这篇文章创新点很棒,引入window这一个概念,将CNN的局部性引入,还能控制模型整体计算量。在Shift Window Attention部分,用一个mask和移位操作,很巧妙的实现计算等价。作者的代码也写得十分赏心悦目,推荐阅读!

猜你喜欢

转载自blog.csdn.net/weixin_54338498/article/details/132412378
今日推荐