Swin Transformer介绍

Swin Transformer发表于ICCV 2021,获得最佳论文,其作者都来自微软亚洲研究院。

Transformer最初是处理NLP(自然语言处理)领域的任务,获得了巨大的成功。逐渐向计算机视觉领域进行拓展,有DETR、ViT以及ViT的变种,使得Transformer在计算机视觉领域大放异彩。

Transformer架构图

Transformer由Encoder和Decoder组成,它是文本语句转化为词向量的一系列处理。

因为语句中的词是有先后顺序的,在计算机视觉领域,为了使用Transformer,需要将一张图片进行切分成若干Patch,再加上一个额外可学习的分类编码。送入Transformer Encoder(多层堆叠),再加上位置编码。再进入MLP Head(多层感知机),最后进行具体的分类。在ViT中只是使用了Transformer的Encoder而没有使用Decoder。ViT的问题:它没有考虑文本和视觉信号的不同,它只能做图像分类,对于目标检测和图像分割没有相应的尝试。

Swin Transformer提供了更加通用的基于Transformer的计算机视觉任务的主干网络,并且能应用到多种计算机视觉任务中,如图像分类、目标检测、语义分割、实例分割等任务。甚至在某些方面性能超过了传统的CNN。

Swin Transformer是根据ViT发展而来的,在ViT中只用了16倍的下采样,经过Transformer Block中,它的形状保持不变,并且主要用于图像的分类。而Swin Transformer开始的时候是使用4倍的下采样,也就是4*4的patch下采样后变成1个像素点。然后还可以做8倍、16倍下采样,可以达到多尺度的特征图的提取效果,它不仅可以用于图像分类,还可以用于目标检测和图像分割。

扫描二维码关注公众号,回复: 14249180 查看本文章

其实Swin Transformer的很多思想和CNN有异曲同工之处,它利用了视觉信号的好的先验,它的网络架构中也采用了层次化(hierarchy)、局部化(locality)、平移不变性(translation invariance)。我们知道CNN可以进行多尺度,层级化的特征提取,主要代表为YOLOV3的FPN网络。在Swin Transformer中,我们也可以看到不同的下采样尺度,它的特征提取的颗粒度在不断的变大,由浅层到深层,它的感受野相比于浅层也是逐渐的扩大。局部性主要体现在它的注意力的计算主要是在窗口中进行的,而ViT是在整个特征图上进行注意力的计算,这样Swin Transformer计算的复杂度就会大大的降低。

主要技术创新

Swin Transformer的主要技术创新就是采用了局部化偏移窗口(Shifted windows)。它是采用非重叠的窗口进行自注意力计算,有关自注意力机制可以参考计算机视觉中的注意力机制 。这种非重叠窗口是在每个尺度的feature map的窗口中进行局部化的自注意力的计算,但是在不同的尺度(层)之间一直是局部化的计算,就缺少窗口之间的信息的交互,所以它还采用了不同层级的窗口的偏移,不仅包含了W_MSA(窗口自注意力机制),还有SW_MSA(偏移窗口自注意力机制)。

Shifted windows技术

在上图的Layer l中有四个蓝色的窗口,自注意计算就是在这种局部非重叠窗口进行的。不同query会共享同样key集合,减少计算量,从而对硬件友好。在Layer l+1中,在前后两层的Transformer模块中,非重叠窗口的配置相比前一层做了半个窗口的移位,也就是蓝色窗口发生了移动,使得上一层中不同窗口信息进行了交换。图中灰色的格子就是一个patch(小块),它是一个4*4像素的大小;红色的格子是一个局部窗口,它是进行自注意的计算,一般包含7*7个patch大小(这里跟图上不同)。我们看到从Layer l到Layer l+1的过程中,蓝色窗口向右下移动了半个窗口的距离,从而使得layer l+1的第一个蓝色窗口有原先4个窗口的信息,这样Layer l+1层相对Layer l层就有相邻窗口之间的信息交互。

Swin Transformer网络架构

首先图片送入网络,先经过块状分区(Patch Partition),再经过线性嵌入(Linear Embedding),总称Patch Embedding。再送入Swin Transformer Block。每个Swin Transformer Block是由两个连续的Swin Transformer Blocks所组成(见最右边),也就是我们上面说的Layer l和Layeer l+1层,其中Layer l层包含的是W-MSA(窗口自注意力机制),而layer l+1层包含的是SW-MSA(偏移窗口自注意力机制)。然后第一个Stage的输出再送到块状拼接(Patch Merging),再送到Swin Transformer Block,这是Stage 2,后面的Stage跟Stage 2是一样的了,只不过Stage 3的Swin Transformer Block不是2个而是6个,表示有三个成对的Layer l和Layer l+1层。

我们再从图片尺寸的角度来看一下整个过程,我们假设送入网络的图片是224*224*3,由于每个patch是4*4像素的大小,那么经过块状分区(Patch Partition)后,就变成了(224/4)*(224/4)*(4*4*3)=56*56*48的尺寸。再经过线性嵌入(Linear Embedding)后,通道数翻倍,就变成了56*56*96=3136*96的尺寸。由于一个窗口有7*7个patch,所以总共有(56/7)*(56/7)=8*8=64个窗口。但是Stage 1经过2个Swin Transformer Block,做了窗口滑动后输出的尺寸依然为56*56*96。在Stage 2中经过块状拼接(Patch Merging)后,尺寸减半,通道数翻倍,变成了28*28*192,再经过2个Swin Transformer Block,输出28*28*192。在Stage 3中,经过块状拼接(Patch Merging)后,尺寸减半,通道数翻倍,变成了14*14*384,再经过6个Swin Transformer Block,也就是窗口滑动了3次,输出14*14*384。在Stage 4中,经过块状拼接(Patch Merging)后,尺寸减半,通道数翻倍,变成了7*7*768,再经过2个Swin Transformer Block,输出7*7*768。

块状拼接(Patch Merging)

我们来看一下块状拼接的过程,假设输入的是一个4*4的feature map,我们对像素进行标号,这个标号不是指的像素值而是像素的位置。这些标号都是相隔一个位置来进行标记。我们把相同标号的像素点进行抽取,将其拆分成4个2*2的小块,每一个小块的标号都相同。将这4个小块进行拼接变成2*2*4的尺寸,4是通道数,再做LayerNorm操作(在通道方向上做像素值的归一化);再经过一个全连接层,转化为2*2*2的feature map。整个过程就是在输入的feature map基础上,尺寸减半,通道数翻倍

Swin Transformer Block

这里我们假设传入Swin Transformer Block的feature map为,它经过了LN(LayerNorm通道层归一化)以及MSA(多头自注意力机制)之后再与直连相加,得到再经过一个LN与MLP(多层感知机)之后再与直连相加得到,整个过程的公式如下

MLP(多层感知机) Block

它是将输入经过全连接层,再经过一个GELU的激活函数,再经过一个Dropout,再经过一个全连接层,再经过一个Dropout。在第一个全连接层,输入通道数为768,输出通道数为4*768=3072。经过第二个全连接层,它的通道数又变回768。

GELU激活函数同ReLU和ELU的比较。

Swin T(Tiny),S(Small),B(Base),L(Large)

Swin Transformer有四个尺寸的模型,win.sz.7x7表示使用的窗口(Windows)的大小,这四个模型使用的都是7*7的窗口;dim表示feature map的通道数,我们可以看到在Swin T的stage 1中的输出通道数为96,Swin S为96,Swin B为128,Swin L为192。head表示多头注意力模块中head的个数,在Swin T的stage 1、2、3、4中的头数分别为3、6、12、24;Swin S相同,Swin B变成了4、8、16、32,Swin L变成了6、12、24、48。成对的Swin Transformer Block在Swin T中的Stage 1、2、3、4中分别为2、2、6、2;Swin S中为2、2、18、2;Swin B、Swin L与Swin S相同。所以从T到L是不断变大的网络模型。具体变化的参数如下

这里C为通道数,layer numbers是Swin Transformer Block在每一个Stage中的数量。

Swin Transformer可以和CNN做一个对比,其中Patch Merging的主要功能是下采样,可以和CNN中的池化做一个类比;Swin Transformer Block可以和CNN中的卷积层做一个类比。不同的Stage进行堆叠,CNN中也经常会进行池化、卷积、池化、卷积.......的堆叠。所以Swin Transformer也是借鉴了CNN中的一些理念或先验。当然它们的实现方式是不同的。

SW-MSA(偏移窗口自注意力机制)

在上图中,我们看到在Layer l中有4个窗口,做Shifted windows操作,就变成了Layer l+1中的9个窗口,并且窗口中patch的数目也不一样。

当4个窗口变成9个窗口后,做自注意力的时候,有一个Naive Solution(简单的想法),就是简单的将周围进行填充,使得所有窗口中的patch数目相同。然后再与中间窗口的大小对所有窗口进行自注意力的计算。周围的部分可以填充0,但是这种想法的计算量会增加,对于效率来说并不可取。

更加有效的进行自注意力批计算的方法

对于9个不同大小的窗口,如果变成相同大小的窗口的时候,我们可以做一个cyclic shift(循环移位),把A部分移动到右下;把B部分移动到右边;把C部分移动到下面。这时我们会发现,我们又获得了4个红色框的窗口,它们的大小都一样,patch数目是一样的。此时我们对这4个窗口做自注意力的计算,通过蒙版去掉不需要的部分。计算结束后,只把保留的部分再通过reverse cyclic shift(反向循环移位),将A、B、C部分还原到原来的特征图上。

特征图移位操作

代码里特征图移位是通过torch.roll来实现的,下面是示意图

上图中,我们先使用roll(x, shirts=-1, dims=0)将第一排数值移动到最下面,再使用roll(x, shirts=-1, dims=1)将变换后的第二张图中的第一列移动到最右边。

import torch

if __name__ == '__main__':

    a = torch.rand(4, 4)
    print(a)
    b = torch.roll(a, shifts=-1, dims=0)
    print(b)
    c = torch.roll(b, shifts=-1, dims=1)
    print(c)

运行结果

tensor([[0.9839, 0.6769, 0.1544, 0.2172],
        [0.5940, 0.5845, 0.8061, 0.6111],
        [0.2948, 0.6773, 0.1506, 0.9417],
        [0.3512, 0.2851, 0.8264, 0.3154]])
tensor([[0.5940, 0.5845, 0.8061, 0.6111],
        [0.2948, 0.6773, 0.1506, 0.9417],
        [0.3512, 0.2851, 0.8264, 0.3154],
        [0.9839, 0.6769, 0.1544, 0.2172]])
tensor([[0.5845, 0.8061, 0.6111, 0.5940],
        [0.6773, 0.1506, 0.9417, 0.2948],
        [0.2851, 0.8264, 0.3154, 0.3512],
        [0.6769, 0.1544, 0.2172, 0.9839]])

经过上面的变动后

原本的9个窗口经过重新排列,就变成后上图中中间的图的样子,再将该图重新4等分成右边的红色窗口。我们以包含5、3的红色窗口为例来说明它的自注意力机制

我们将该窗口中的4个patch分别展开成行向量和列向量,如上图中的中间和右边的图所示

我们将这两个向量进行两两组合,得到一个矩阵。由于是自注意力机制,我们只关注5或者3本身,所以我们将都为5或者都为3的保留,设为1;将不同的去掉,设为0。这个过程被称为Mask MSA。当然我们这里只是以4*4的窗口来说明,而在Swin Transformer中一个窗口的大小是7*7的。

在上图中有四个窗口,其中Window0因为只有一个0,所以我们不需要Mask。而Window1、Window2、Window3都需要做Mask,并且它们Mask的模式(Patten)是不一样的。Window1根据之前的分析可知,它是条纹状的Mask;Window2是分块的Mask,我们可以把去掉的部分加上一个较大的负数,比如说-100,再经过softmax变成0。Window3有4个值,它的Mask是Window1和Window2的模式的融合。这就称为Attn Mask

相对位置偏移(Relative Position Bias)详解

在计算自注意力的时候,位置变化不是加在输入上,而是结合在Attention的计算之中。

上式中的Q、K、V分别代表查询向量矩阵、键向量矩阵、值向量矩阵,这三个矩阵中每一行分别代表一个对应的向量。上式中的B就是相对位置偏移。

相对位置索引(Relative Position Index)

上图中的窗口中有2*2个patch,我们分别给这四个位置标上绝对位置索引,分别为(0,0)、(0,1)、(1,0)、(1,1),第一个序号代表行,第二个序号代表列。以左上角的(0,0)为基准,相对位置索引指的是其他的位置的索引表示为用(0,0)减去该位置的绝对位置索引得到的值,如原本为(0,1)的相对位置索引为(0,0)-(0,1)=(0,-1),其余的相同,这里需要注意(0,0)位置本身也要减去(0,0),只不过它不变罢了。

我们分别以不同的位置作为基准,可以分别得到其他位置的相对位置索引。如以(0,1)位置作为基准,(0,1)-(0,0)=(0,1),(0,1)-(0,1)=(0,0),其他的以此类推。最终我们将各个相对位置索引展开成一个行向量,再进行拼接就得到了上图中下面的矩阵。

再将该矩阵加上一个M-1,M为窗口的大小,在Swin Transformer中为7,我们这里为2。这样就得到了上图中第二个矩阵中的值,再将每一个行标都乘以2M-1,得到上图中第三个矩阵中的值。最后将行标和列标求和,就得到最后一个矩阵的值,这个矩阵中的值就是相对位置索引

这个相对位置索引需要去索引的值会有一个相对位置偏置表(relative position bias table)

这个表的元素的个数为(2M-1)*(2M-1),在Swin Transformer中为13*13=169个,我们这里为3*3=9个。

MSA计算复杂度

这里给出了MSA和W-MSA的计算复杂度,这里h、w代表特征图的高度和宽度。在多头自注意力MSA的复杂度公式中,有一个(hw)^2项;而采用窗口自注意力W-MSA的复杂度公式中则没有该平方项。如果有平方项的话,则该值比较大,所以采用W-MSA的话,复杂度就会大幅度降低。在W-MSA中有一个M,该M表示窗口的patch数量,在Swin Transformer中M=7。

我们来看一下MSA的计算复杂度公式是怎么来的,在自注意力Attention自身的公式中,Q、K、V是通过特征图X乘以三个矩阵得来的

X是hwC,而三个矩阵都是C*C的,每一个乘法的时间复杂度为hwC^2,由于有3个,所以这一步的时间复杂度为3hwC^2。

在求得了Q、K、V之后,就到了的时间复杂度,我们知道Q为hwC,为Chw,所以的时间复杂度为。我们忽略掉除以和softmax的时间复杂度,又与V有一个矩阵乘,它的时间复杂度也为,所以这一步总的时间复杂度为

最后一步,多头注意力模块比单头注意力模块多了最后一个融合矩阵运算,就是这个Linear全连接层的运算。它的时间复杂度也是hwC^2,所以全部过程的时间复杂度为

W-MSA的计算复杂度

上面我们知道了

对于窗口自注意力机制中的每一个窗口的h=M,w=M,M为7。所以每个窗口的时间复杂度为

由于整个特征图有(h/M)*(w/M)个窗口,这里h、w代表整个特征图的高和宽。所以整个特征图的时间复杂度为

故最终

Swin Transformer实验结果及性能

Swin Transformer对三个方向进行了实验,一个是图像分类,使用的是ImageNet 1K分类的数据集;一个是图像检测,使用的是coco物体检测数据集;一个是像素级的语义分割,使用是ADE20K数据集。

上表是对Image-1K的数据集的实验结果,RegNet是FaceBook搜索出来的一个比较好的网络模型,EffNet是谷歌的模型,它们都是CNN的网络模型。下面是基于Transformer的网络模型,包括ViT,DeiT,Swin Transformer。我们可以看到ViT比CNN的精度还是要低的;DeiT用了更多的数据增强,所以它的性能有所提高,与CNN的性能比较接近;而使用了Swin Transformer后的精度已经超过了CNN。

上表是在ImageNet-22K上做预训练,在ImageNet-1K上做训练的结果。我们可以看到ViT的性能也上来了,接近CNN的网络性能。而使用了Swin Transformer后有了更大的提升。

在目标检测上的性能,我们可以看到在a部分几种主要的目标检测算法Mask RCNN、ATSS、RepPointsV2以及Sparse RCNN中,我们将主干网络都由ResNet50替换成Swin Transformer后,它们的性能都有所提升。在b部分,我们固定一种Mask RCNN后,采用几种不同规格的Swin Transformer,Swin-B的性能是最好的。

在语义分割的性能上,相比于传统的CNN的网络,使用了Swin Transformer作为主干网络的UperNet也取得了比较好的性能。

最后是关于消融实验,主要对比的是三个数据集,我们看到采用了shited windows技术,它们的精度都有提升。另外采用相对位置偏置的方法也比不采用的方法,精度都有提升。相比于图像分类,目标检测和语义分割有更大的提升,这是因为目标检测和语义分割它们对位置信息比图像分类更加的敏感

代码解析

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional

我们直接从Swin Transformer的主代码入手

class SwinTransformer(nn.Module):
    """ Swin Transformer网络模型
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        """
        patch_size:一个patch中的像素点数
        in_chans:进入网络的图片通道数
        num_classes:分类数量
        embed_dim:feature map通道数
        depths:各个Stage中,Swin Transformer Block的数量
        num_heads:多头注意力各个Stage中的头数
        window_size:窗口自注意力机制的窗口中的patch数
        mlp_ratio:多层感知机模块中第一个全连接层输出的通道倍数
        qkv_bias:如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
        drop_rate:dropout rate,默认为0
        attn_drop_rate:用于自注意力机制中的dropout rate,默认为0
        drop_path_rate:在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
        LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0.1
        norm_layer:通道方向归一化
        patch_norm:如果是True的话,在patch embedding之后加上归一化
        use_checkpoint:是否使用Pytorch中间数据保存机制
        """
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # stage4输出特征的通道数(Swin-Tiny:768)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # 把图像分割成不重叠的patch
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

这里我们来看一下PatchEmbed类的实现

class PatchEmbed(nn.Module):
    """
    包括块状分区(Patch Partition)和线性嵌入(Linear Embedding)
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        # 使用一个卷积来完成块状分区和线性嵌入两个过程,4倍降采样,输出通道数变成96
        # 卷积核大小与步长相同,不会有重叠区域被划分
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 通道归一化
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # 在左、右、上、下、前、后进行填充
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        # 在通道方向进行归一化
        x = self.norm(x)
        return x, H, W

回到SwinTransformer类中

# stochastic depth (Swin-Tiny最大的depths数目=2+2+6+2)
# 不同的stage,舍弃整个直连分支的概率不同,从小到大,最小为0,最大为0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):  # layer相当于stage
    # 这里的stage不包含本stage的patch_merging层,包含的是下个stage的patch_merging层
    layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                        depth=depths[i_layer],
                        num_heads=num_heads[i_layer],
                        window_size=window_size,
                        mlp_ratio=self.mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                        norm_layer=norm_layer,
                        # 只有前3个stage有patchmerging,最后一个没有
                        downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                        use_checkpoint=use_checkpoint)
    self.layers.append(layers)

这里我们来看一下BasicLayer类

class BasicLayer(nn.Module):
    """
    A basic Swin Transformer layer for one stage.
    一个BasicLayer包含偶数个SwinTransformerBlock和一个downsample层(即Patch Merging层)
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        """
        dim: feature map通道数
        depth:各个Stage中,Swin Transformer Block的数量
        num_heads:多头注意力各个Stage中的头数
        window_size:窗口自注意力机制的窗口中的patch数
        mlp_ratio:多层感知机模块中第一个全连接层输出的通道倍数
        qkv_bias:如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
        drop:dropout rate,默认为0
        attn_drop:用于自注意力机制中的dropout rate,默认为0
        drop_path:在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
        LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0
        norm_layer:通道方向归一化
        downsample:使用Patch Merging来降采样
        use_checkpoint:是否使用Pytorch中间数据保存机制
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        # shift-window尺寸,大小为3
        self.shift_size = window_size // 2

        # build SwinTransformerBlock
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                # 用于区分是使用W-MSA还是SW-MSA,0为W-MSA,1为SW-MSA
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        # 当stage=4的时候为None
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

这里我们来看一下SwinTransformerBlock类

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        """
        dim:feature map通道数
        num_heads:多头注意力各个Stage中的头数
        window_size:窗口自注意力机制的窗口中的patch数
        shift_size:shift-window尺寸
        mlp_ratio:多层感知机模块中第一个全连接层输出的通道倍数
        qkv_bias:如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
        drop:dropout rate,默认为0
        attn_drop:用于自注意力机制中的dropout rate,默认为0
        drop_path:在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
        LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0
        act_layer:多层感知机模块中第一个全连接层后的激活函数
        norm_layer:通道方向归一化
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        # 窗口自注意力机制
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        # 多层感知机经过第一个全连接层的通道数
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 多层感知机
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

这里我们来看一下WindowAttention类

class WindowAttention(nn.Module):
    """ 多头窗口自注意力机制(W-MSA)
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        """
        dim:feature map通道数
        window_size:窗口自注意力机制的窗口中的patch数
        num_heads:多头注意力各个Stage中的头数
        qkv_bias:如果是True的话,对自注意力公式中的Q、K、V增加一个可学习的偏置
        attn_drop:用于自注意力机制中的dropout rate,默认为0
        proj_drop:第二个全连接层的dropout rate,默认为0
        """

        super().__init__()
        # 96
        self.dim = dim
        # [7, 7]
        self.window_size = window_size  # [Mh, Mw]
        # 3
        self.num_heads = num_heads
        # 32,Attention公式中的d
        head_dim = dim // num_heads
        # 1/√d
        self.scale = head_dim ** -0.5

        # 定义一个相对位置偏移(relative position bias)参数表
        # 大小为(2M - 1)*(2M - 1),M就是window_size
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # 获取窗口内每个标记的成对相对位置索引
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # 标上绝对位置索引矩阵
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        # 展平成向量
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        # 以每一个点为基准获取相对位置索引矩阵
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        # 该矩阵加上一个M-1,M为window_size
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # 每个行标都乘以2M-1,M为window_size
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 将行标和列标求和,得到相对位置索引
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        # 将相对位置索引进行注册,在训练的时候不需要改变
        self.register_buffer("relative_position_index", relative_position_index)
        # 使用全连接层即矩阵乘法来计算Q、K、V值
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # 对相对位置偏移(relative position bias)参数表进行截断的正态分布初始化
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        x: 输入的特征图(num_windows*BatchSize, Mh*Mw, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        # 获取特征图属性
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # 通过矩阵乘法来计算Q、K、V的值
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # 提取Q、K、V的值
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        # Q/√d
        q = q * self.scale
        # QK^T/√d
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        # QK ^ T /√d + B,B就是相对位置偏移
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            # 做softmax
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        # softmax后和V做矩阵乘法
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        # 再经过一个全连接层
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

我们看一下DropPath类

class DropPath(nn.Module):
    """丢弃直连分支,影响网络的深度
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)

然后是drop_path_f这个方法

def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    """每个样本的下落路径(随机深度)(当应用于剩余块的主路径时)
    x:feature map
    drop_prob:在Swin Transformer Block中,有一定概率丢弃整个直连分支,包括
        LN、W-MSA或者SW-MSA,只保留直连的连接,是一种网络深度的随机性,默认为0.1
    training:是否为训练阶段
    """
    # 如果drop_prob为0或者为非训练,直接返回特征图
    if drop_prob == 0. or not training:
        return x
    # 获取保留直连分支的概率
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # 将不丢弃的值扩大,类似于dropout的处理
    output = x.div(keep_prob) * random_tensor
    return output

我们再看一下Mlp类

class Mlp(nn.Module):
    """ 多层感知机
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

现在我们回到SwinTransformerBlock类

def forward(self, x, attn_mask):
    H, W = self.H, self.W
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    # 通过LN层
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # pad feature maps to multiples of window size
    # 把feature map给pad到window size的整数倍
    pad_l = pad_t = 0
    pad_r = (self.window_size - W % self.window_size) % self.window_size
    pad_b = (self.window_size - H % self.window_size) % self.window_size
    x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
    # 获取填充后的高度和宽度
    _, Hp, Wp, _ = x.shape

    # 如果是移动窗口自注意力机制(SW-MSA),做循环移位,这里是将上下移动和左右移动合并到一步完成
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:  # 否则不做处理
        shifted_x = x
        attn_mask = None

    # 对移动完的feature map进行窗口分区
    x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

这里我们来看一下window_partition方法

def window_partition(x, window_size: int):
    """
    将feature map按照window_size划分成一个个不重叠的window
    Args:
        x: (B, H, W, C)
        window_size (int): window size(M)

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
    # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

SwinTransformerBlock类继续

# 进行窗口自注意力机制的计算(W-MSA/SW-MSA)
attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

# 合并一个个7*7的窗口成一个新的feature map
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, H', W', C]

# 反向循环移位
if self.shift_size > 0:
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
    x = shifted_x

我们来看一下window_reverse方法

def window_reverse(windows, window_size: int, H: int, W: int):
    """
    将windows中的一个个窗口还原成一个feature map
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size(M)
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    # view: [B*num_windows, Mh, Mw, C] -> [B, H//Mh, W//Mw, Mh, Mw, C]
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # permute: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B, H//Mh, Mh, W//Mw, Mw, C]
    # view: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H, W, C]
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

SwinTransformerBlock类继续

if pad_r > 0 or pad_b > 0:
    # 把前面pad的数据移除掉
    x = x[:, :H, :W, :].contiguous()

x = x.view(B, H * W, C)

# 进行直连相加操作
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))

return x

现在我们回到BasicLayer类

def create_mask(self, x, H, W):
    # calculate attention mask for SW-MSA
    # 保证Hp和Wp是window_size的整数倍
    Hp = int(np.ceil(H / self.window_size)) * self.window_size
    Wp = int(np.ceil(W / self.window_size)) * self.window_size
    # 和feature map具有一样的通道排列顺序,方便后续window_partition
    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    # 给每个window编号
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # mask_windows也进行window切分
    mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
    # 通过广播操作相减,index相同的attn_mask将变为0
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
    # [nW, Mh*Mw, Mh*Mw]
    # mask掉的部分改为-100,这样经过softmax后,对应位置的值就会很小
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask

def forward(self, x, H, W):
    attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
    # 通过每一个SwinTransformerBlock
    for blk in self.blocks:
        blk.H, blk.W = H, W
        if not torch.jit.is_scripting() and self.use_checkpoint:
            x = checkpoint.checkpoint(blk, x, attn_mask)
        else:
            x = blk(x, attn_mask)
    # 进行块状拼接(PatchMerging)下采样
    if self.downsample is not None:
        x = self.downsample(x, H, W)
        H, W = (H + 1) // 2, (W + 1) // 2

    return x, H, W

这里的downsample就是PatchMerging类,我们来看一下该类

class PatchMerging(nn.Module):
    """ 块状拼接层,尺寸减半,通道数翻倍
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        """
        dim:feature map通道数
        norm_layer:通道方向归一化
        """
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        # 拆分小块
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        # 将小块拼接
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        # 在通道方向上做像素值的归一化
        x = self.norm(x)
        # 经过一个全连接层
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

回到SwinTransformer类中

    # 4个stage之后,接一个通道归一化,平均池化层,最后跟一个输出多分类的全连接层
    self.norm = norm_layer(self.num_features)
    self.avgpool = nn.AdaptiveAvgPool1d(1)
    self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    # 初始化权重和偏置
    self.apply(self._init_weights)

def _init_weights(self, m):
    """
    对全连接层或者通道归一化进行权重以及偏置的初始化
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

def forward(self, x):
    # x: [B, L, C]
    # 图像分割
    x, H, W = self.patch_embed(x)
    x = self.pos_drop(x)
    # 通过各个stage
    for layer in self.layers:
        x, H, W = layer(x, H, W)

    x = self.norm(x)  # [B, L, C]
    x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
    x = torch.flatten(x, 1)
    x = self.head(x)
    return x
{{o.name}}
{{m.name}}

猜你喜欢

转载自my.oschina.net/u/3768341/blog/5529722