CVPR 2023 | Making Vision Transformers Efficient from A Token Sparsification View

CVPR 2023 | Making Vision Transformers Efficient from A Token Sparsification View

image.png

CVPR 2023

主要内容:

  • 基于 token sparsification 的方式, 通过提取 semantic token 替代 image token, 缩小了 Vision Transformer 的计算复杂度, 同时较好的维持了原有的性能.
  • 提出的结构同时适用于全局注意力和局部注意力两种 Vision Transformer 形式.
  • 基于分辨率恢复模块, 可以在压缩计算复杂的同时, 维持原有的特征分辨率不变, 便于下游任务的迁移.

现有方法的问题

与 token 个数呈平方的计算复杂度限制了 Vision Transformer 的实际应用. 众多优化策略中, 缩减 token 数量是最直接的一种. 最近的一些工作中说明了图像 token 中存在大量冗余, 并给出了按照预先定义的评分机制来过滤不重要 token 的方法. 然而, 这些方法面临以下挑战.

  • 首先, 用于过滤的预先定义的评分机制通常是不精确的. 不同的层具有不同的值分布, 使用这些不精确的分数进行过滤会导致性能不理想.
  • 其次, 剩余的 token 在空间上不再均匀分布, 使得它们难以在局部 vision transformer 中使用.
  • 最后, 大规模的 token 剪枝会极大地破坏空间结构和位置信息, 在应用于下游任务时造成困难, 这些方法中没有提出解决方案.

token 稀疏化

这类方法主要可分为硬和软剪枝 (pruning) 两种.

  • 硬剪枝方法根据预定义的评分机制过滤掉一些不重要的 token. 然而这很难实现精确的评分. 因此, 它们通常会出现精度显着下降的问题.
    • DynamicViT、SPViT 和 AdaViT 引入了额外的预测网络来对 token 进行评分.
    • Evo-ViT、ATS 和 EViT 利用类 token 的值来评估 token 的重要性.
  • 软剪枝方法通过导入额外的注意力网络从图像 token 生成新 token. 但是现有方案性能损失仍旧严重.
    • TokenLearner 也主张用一些 token 来代替图像 token.

除了性能下降, 现有的方法还有以下缺点:

  • 首先, 是否或如何将这些方法扩展到局部 ViT 仍未得到探索.
  • 其次, 还没有讨论在 token 被修剪后如何为下游任务服务.

本文所提方法使用现成的 transformer 层来减少 token 数量. 现有的一些方法采用类似的方法来实现有效的非局部关系. 但本文方法与它们的不同之处在于:

  • 所提方法提取局部语义信息, 而不是非局部关系;
  • 语义 token 是少量的聚类中心, 可以替代海量的图像 token, 实现图像分类;
  • 我们的方法是专门用于修剪 token 的.

所提方法的思路

image.png

为了解决这些问题, 本文提出了 STViT, 可同时用于全局和局部 ViT, 同时修改版本也适用于下游任务. 所提出的方法基于以下观察:

  1. 与学习图像空间结构的局部 CNN 不同, ViT 将特征图离散化为用于全局特征探索 token, 缓解了维持整个图像结构和信息的需求;
  2. 离散 token 更有利于优化 [Scaled ReLU Matters for Training Vision Transformers];
  3. 在图 1 中, 右侧显示了不同 transformer 层中的注意力图, 深层中只有几条垂直线, 这意味着只有少数具有全局语义信息的 token 很重要.

因此, 我们认为没有必要为 ViT 维护大量结构化 token, 尤其是在深层. 使用一些具有高级语义信息的离散 token 可能会获得高的性能和效率.
在 STViT 中, 语义 token 代表聚类中心, 它们的数量远少于原始图像 token, 显着降低了计算成本. 受多头注意力可以进行聚类中心恢复(证明件论文附录 6.6) 这一事实的启发, 仅使用现成自注意力来生成语义 token.

所提模型的细节

image.png

STGM

前几个 transformer 层保持不变以获得具有低级特征的图像 token. 然后将图像 token 输入 STViT 的核心模块语义 token 生成模块 (STGM).
该模块由至少两个 transformer 层组成, 以生成语义 token, 在每个自注意力层中, 语义 token 作为 Q 输入, 图像 token 作为 K 和 V 输入. 语义 token 通过注意力层动态聚合图像 token 以获得 (恢复) 聚类中心.
在第一个注意层中, 语义 token 由窗口内和窗口间的空间组合池化策略进行初始化. 这考虑到在每个窗口中合并语义信息并最大化相邻窗口之间的距离. 并且由于这种空间初始化, 语义 token 主要结合局部语义信息并在空间中实现离散和均匀分布.

image.png

窗口内和窗口间的空间组合池化策略过程较为琐碎, 每个聚类的中心 P i = s o f t m a x ( M i + O i ) ⋅ X w P_i=softmax(M_i + O_i) \cdot X_w Pi=softmax(Mi+Oi)Xw 获得. 这里涉及到两个分别对应于窗口内空间注意力 logits M i M_i Mi 和跨窗口空间注意力 logits O i O_i Oi, 这里引入了两个超轻量的结构来进行特征变换. 后者的生成依赖于前者的结果. 整体流程如下:

  1. 给定原始特征图 X ∈ R H × W × C X \in \mathbb{R}^{H \times W \times C} XRH×W×C 可以划分成 N s = w s × w s N_s = w_s \times w_s Ns=ws×ws 个空间大小为 H w s × W w s \frac{H}{w_s} \times \frac{W}{w_s} wsH×wsW 的特征 window。每个 window 都会最终生成一个初始聚类中心。
  2. 计算 window 内的信息聚合注意力 logits M i ∈ R H w s × W w s = C o n v ( G e L U ( L N ( D e p t h C o n v ( X w i ) ) ) ) M_i \in \mathbb{R}^{\frac{H}{w_s} \times \frac{W}{w_s}} = Conv(GeLU(LN(DepthConv(X^i_w)))) MiRwsH×wsW=Conv(GeLU(LN(DepthConv(Xwi)))). 这里表示将第 i 个窗口内的 X 进行了变换,通道数量压缩为 1。
  3. 计算 window 间的信息聚合注意力 logits O i O_i Oi:
    1. 这里需要用到已经获得的 M i M_i Mi 来计算局部窗口集成 token P ^ i = s o f t m a x ( M i ) ⋅ X w i \hat{P}_i = softmax(M_i) \cdot X^i_w P^i=softmax(Mi)Xwi 按照空间排布构建的 2D tensor P ^ ∈ R w s × w s × C \hat{P} \in \mathbb{R}^{w_s \times w_s \times C} P^Rws×ws×C.
    2. 接下来计算 O i = C o n v ( G e L U ( L N ( D e p t h C o n v ( X w i ) ) ) ) O_i = Conv(GeLU(LN(DepthConv(X^i_w)))) Oi=Conv(GeLU(LN(DepthConv(Xwi)))), 这里将通道数量从 C 变为了 H W w s 2 \frac{HW}{w_s^2} ws2HW, 变形后即可得到最终的 O i ∈ R H w s × W w s O_i \in \mathbb{R}^{\frac{H}{w_s} \times \frac{W}{w_s}} OiRwsH×wsW.

在接下来的注意力层中, 除了进一步聚类外, 语义 token 还引入了通过高斯噪声初始化的全局聚类中心 G ∈ R N s × C G \in \mathbb{R}^{N_s \times C} GRNs×C, 网络从而可以自适应地选择部分语义 token 来关注全局语义信息. 作者在这里强调了这一设计与可学习的位置编码的差异, 这里的全局聚类中心直接被加到了 Q 上, 而没有被加到 K 中, 这与位置编码不同. 同时作者的试验里也展示了这样设计比实际的位置编码更好一些.

image.png

STGM 之后原始图像 token 被丢弃, 只为后续的 transformer 层保留语义 token. 因为语义 token 的生成是灵活的并且是空间感知的, 所提提出的方法可以插入到全局和局部 ViT. 对于局部 ViT, 每个窗口中独立生成数个语义 token.

在局部 ViT 中的应用

不同于全局 ViT, 局部 ViT 中多了局部窗口的概念. 提出的 STGM 自身的划分子窗口初始化聚类中心的操作被局限在各个局部窗口中. 所以对于具有 N w = w × w N_w = w \times w Nw=w×w 个窗口的局部 ViT 而言, 其实际对应的 semantic token 有 w s × w s × N w w_s \times w_s \times N_w ws×ws×Nw 个.
虽然初始聚类中心来自 w × w w \times w w×w 个窗口, 但使用大小为 w k × w k w_k \times w_k wk×wk 的较大窗口来获取 K 和 V, 并且全局聚类中心的设计也可以缓解有限窗口大小的影响.
在局部 ViT 模型中, 每个局部 transformer 层通常后面跟着一个跨窗口连接层, 例如 Swin Transformer 上的局部 transformer 层之后的移位窗口 transformer 层. 在本文方法中, 注意力是在局部自注意力层的 w s × w s w_s \times w_s ws×ws 窗口内计算的, 跨窗口连接可以通过在更大尺寸 (例如 4 × w s 4 \times w_s 4×ws) 的滑动窗口中计算自注意力来实现, 因为每个窗口中的 token 数量较少. 对于低分辨率输入, 这里的跨窗口连接层相当于一个全局自注意力层.

适配下游任务迁移需求

值得注意的是 token 稀疏化策略会大量丢失空间信息, 之前所有的相关方案都没有讨论如何在下游任务中使用它们. 这实际上严重阻碍了它们的应用. 本文基于 STViT 提出了 STViT-R, 其采用恢复模块和哑铃单元来周期性恢复全分辨率特征图, 同时中间的 transformer 层继续使用语义 token 来节省计算成本. 这样的设计增强了所提方案对于下游任务的迁移能力.

  • 恢复模块: 使用原始图像 token X 作为 Q, 之前的 semantic token S 作为 K 和 V, 通过 transformer 层从而获得更新后的图像 token.
  • 哑铃单元: 在 STViT-R 中, transformer 层被重新组合成多个哑铃单元. 每个哑铃单元由四个部分组成. 第一部分中的转换器负责处理图像 token; 第二部分是语义 token 生成模块; 第三部分中的转换器层处理语义 token; 最后一部分是恢复模块. 通过重复多个哑铃单元, 网络将保留详细的空间信息, 这不仅可以增强分类, 还可以服务于下游任务.

实验结果

image.png

image.png

image.png

image.pngimage.png

表 7 中, 通过向 K 中添加全局初始聚类中心来充当真实的可学习位置编码.

  • 表 8 中的实验限制了总体层数不变, 只是调整了各个单元的之间的层数分配.
  • 表 9 中的试验限制了所有模型 FLOPs 与完整模型一致.

image.png

这里也尝试将位置编码应用于语义 token 中. 上表显示了不同位置编码方法的比较, 包括学习位置编码、条件位置编码和相对位置编码. 尽管相对位置编码将 Swin-T 提高了 1.2%, 但所有位置编码方法都不适用于 DeiT-S 和 Swin-T. 这些实验表明, 提出的语义 token 之间的交互依赖于高级语义信息, 几乎不使用位置关系.

image.png

这里测试了三种额外的空间池化策略来固定获得 25 个 semantic token, 但是提出的方案性能最好:

  • 大核和重叠空间池化
  • 多尺度空间池化
  • 自适应空间池化

核心代码

代码可见https://github.com/changsn/STViT-R/blob/main/models/swin_transformer.py, 内容很乱, 需要仔细梳理. 另外从issue中可以了解, 释放出来的代码也并不完整, 仅可以当做一个理解方法的原型.

四阶段哑铃单元

对应于前述内容提到的四阶段的结构:

for i, blk in enumerate(self.blocks):
    # 哑铃单元 1
    if i == 0:
        x = blk(x)
    elif i == 1:
        semantic_token = blk(x)
    elif i == 2:
        if self.use_global:
            semantic_token = blk(semantic_token+self.semantic_token2, torch.cat([semantic_token, x], dim=1))
        else:
            semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 2 and i < 5:
        semantic_token = blk(semantic_token)
    elif i == 5:
        x = blk(x, semantic_token)

    # 哑铃单元 2
    elif i == 6:
        x = blk(x)
    elif i == 7:
        semantic_token = blk(x)
    elif i == 8:
        semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 8 and i < 11:
        semantic_token = blk(semantic_token)
    elif i == 11:
        x = blk(x, semantic_token)

    # 哑铃单元 3
    elif i == 12:
        x = blk(x)
    elif i == 13:
        semantic_token = blk(x)
    elif i == 14:
        semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 14 and i < 17:
        semantic_token = blk(semantic_token)
    else:
        x = blk(x, semantic_token)

单取其中一个:

if i == 0: # SwinTransformerBlock with 0 shift
    x = blk(x)
elif i == 1: # SemanticAttentionBlock 对应STGM的第一个transformer block
    semantic_token = blk(x)
elif i == 2: # Block(Global Attention)对应STGM的第二个transformer block
    if self.use_global:
        semantic_token = blk(semantic_token + self.semantic_token2,
                             torch.cat([semantic_token, x], dim=1))
    else:
        semantic_token = blk(semantic_token,
                             torch.cat([semantic_token, x], dim=1))
elif i > 2 and i < 5: # Local Attention(3) -> Global Attention(4)
    semantic_token = blk(semantic_token)
elif i == 5: # Global Attention
    x = blk(x, semantic_token)

这里引入了 semantic_token2 作为全局初始聚类中心, 加到原始信息流中的semantic token:

self.use_global = use_global
if self.use_global:
    self.semantic_token2 = nn.Parameter(torch.zeros(1, self.num_samples, embed_dim))
    trunc_normal_(self.semantic_token2, std=.02)

可以看到, 这里实际上在单个哑铃单元中执行了这样的过程:

  1. X=Local-MHSA(X)
  2. S=SemanticAttentionBlock(X)
  3. S=Global-MHSA(S, [S,X])
  4. S=Local-MHSA(S)
  5. S=Global-MHSA(S)
  6. X=Global-MHSA(X, S)

SemanticAttentionBlock

这里应该就是前文所述的STGM的第一个transformer block了, 但是细节上与论文中的表述有所不同. 这里的K和V同时利用了原始图像token和池化后初始构建的semantic token.

class SemanticAttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, multi_scale, window_size=7, sample_window_size=3, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale_init_value=1e-5,
                 use_conv_pos=False, shortcut=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.multi_scale = multi_scale(sample_window_size)
        self.attn = Attention(dim, num_heads=num_heads, window_size=None, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.use_conv_pos = use_conv_pos
        if self.use_conv_pos:
            self.conv_pos = PosCNN(dim, dim)
        self.shortcut = shortcut
        self.window_size = window_size
        self.sample_window_size = sample_window_size

    def forward(self, x, y=None):
        B, L, C = x.shape
        H = W = int(math.sqrt(L))
        x = x.view(B, H, W, C)
        if y == None:
            xx = x.reshape(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C)
            windows = xx.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, self.window_size, self.window_size, C).permute(0, 3, 1, 2)
            shortcut = self.multi_scale(windows)  # B*nW, W*W, C
            if self.use_conv_pos:
                shortcut = self.conv_pos(shortcut)
            pool_x = self.norm1(shortcut.reshape(B, -1, C)).reshape(-1, self.multi_scale.num_samples, C)
        else:
            B, L_, C = y.shape
            H_ = W_ = int(math.sqrt(L_))
            y = y.reshape(B, H_ // self.sample_window_size, self.sample_window_size, W_ // self.sample_window_size, self.sample_window_size, C)
            y = y.permute(0, 1, 3, 2, 4, 5).reshape(-1, self.sample_window_size*self.sample_window_size, C)
            shortcut = y
            if self.use_conv_pos:
                shortcut = self.conv_pos(shortcut)
            pool_x = self.norm1(shortcut.reshape(B, -1, C)).reshape(-1, self.multi_scale.num_samples, C)

        # produce K, V
        k_windows = F.unfold(x.permute(0, 3, 1, 2), kernel_size=10, stride=4).view(B, C, 10, 10, -1).permute(0, 4, 2, 3, 1)
        k_windows = k_windows.reshape(-1, 100, C)
        k_windows = torch.cat([shortcut, k_windows], dim=1)
        k_windows = self.norm1(k_windows.reshape(B, -1, C)).reshape(-1, 100+self.multi_scale.num_samples, C)

        if self.shortcut:
            x = shortcut + self.drop_path(self.layer_scale_1 * self.attn(pool_x, k_windows))
        else:
            x = self.layer_scale_1 * self.attn(pool_x, k_windows)
        x = x.view(B, H // self.window_size, W // self.window_size, self.sample_window_size, self.sample_window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, C)
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x

猜你喜欢

转载自blog.csdn.net/P_LarT/article/details/131226411