CVPR 2023 | Making Vision Transformers Efficient from A Token Sparsification View


CVPR 2023

Main content:

  • Based on token sparsification, the method extracts a small set of semantic tokens to replace the image tokens, reducing the computational complexity of the Vision Transformer while maintaining its original performance.
  • The proposed structure is applicable to both global-attention and local-attention Vision Transformers.
  • A resolution-recovery module restores the original feature resolution while the computation stays compressed, which makes transfer to downstream tasks convenient.

Problems with Existing Methods

The computational complexity of the Vision Transformer is quadratic in the number of tokens, which limits its practical application. Among the many optimization strategies, reducing the number of tokens is the most direct one. Recent works have shown that there is a lot of redundancy among image tokens and have proposed methods that filter out unimportant tokens according to a predefined scoring mechanism. However, these methods face the following challenges.

  • First, the pre-defined scoring mechanism for filtering is usually imprecise. Different layers have different value distributions, and filtering with these imprecise scores can lead to suboptimal performance.
  • Second, the remaining tokens are no longer spatially uniform, making them difficult to use in local vision transformers.
  • Finally, large-scale token pruning severely destroys the spatial structure and positional information, which causes difficulties when the model is applied to downstream tasks, and these methods do not propose a solution for this.

Token sparsification

These methods can be mainly divided into hard and soft pruning.

  • Hard pruning methods filter out some unimportant tokens according to a predefined scoring mechanism. However, it is difficult to achieve accurate scoring. Therefore, they usually suffer from a significant drop in accuracy.
    • DynamicViT, SPViT and AdaViT introduce additional predictive networks to score tokens.
    • Evo-ViT, ATS, and EViT use the attention values of the class token to evaluate the importance of tokens.
  • Soft pruning methods generate new tokens from the image tokens by introducing an additional attention network. However, the performance loss of existing schemes is still significant.
    • TokenLearner likewise replaces the image tokens with a small number of learned tokens.

In addition to performance degradation, existing methods have the following disadvantages:

  • First, whether or how to extend these methods to local ViT remains unexplored.
  • Second, there is no discussion on how to serve downstream tasks after tokens are pruned.

The method proposed in this paper uses off-the-shelf transformer layers to reduce the number of tokens. Some existing methods use a similar mechanism to model effective non-local relations, but the method in this paper differs from them in that:

  • The proposed method extracts local semantic information, rather than non-local relations;
  • Semantic tokens are a small number of cluster centers, which can replace massive image tokens to achieve image classification;
  • Our method is dedicated to pruning tokens.

The idea of the proposed method


To address these issues, this paper proposes STViT, which can be used in both global and local ViTs, while a modified version (STViT-R) is also suitable for downstream tasks. The proposed method is based on the following observations:

  1. Different from CNNs, which rely on the local spatial structure of the image, ViT discretizes the feature map into tokens for global feature exploration, which alleviates the need to maintain the entire image structure and information;
  2. Discrete tokens are more conducive to optimization [Scaled ReLU Matters for Training Vision Transformers];
  3. The right side of Figure 1 shows the attention maps of different transformer layers; in the deep layers only a few vertical lines remain, which means that only a few tokens carrying global semantic information are important.

Therefore, the authors argue that it is not necessary to maintain a large number of structured tokens in ViT, especially in the deep layers, and that a few discrete tokens with high-level semantic information may achieve both high performance and efficiency. In STViT, the semantic tokens represent cluster centers, and their number is much smaller than the number of original image tokens, which significantly reduces the computational cost. Inspired by the fact that multi-head attention can perform cluster-center recovery (proved in Appendix 6.6 of the paper), only off-the-shelf self-attention layers are used to generate the semantic tokens.

Details of the proposed model


STGM

The first few transformer layers remain unchanged to obtain image tokens with low-level features. The image tokens are then fed into STViT's core module, the Semantic Token Generation Module (STGM), which consists of at least two transformer layers that generate the semantic tokens. In each self-attention layer, the semantic tokens serve as the query Q, while the image tokens serve as the key K and value V. The semantic tokens dynamically aggregate the image tokens through the attention layer to obtain (recover) the cluster centers. In the first attention layer, the semantic tokens are initialized by an intra- and inter-window spatial pooling strategy, which incorporates semantic information within each window while maximizing the distance between adjacent windows. Thanks to this spatial initialization, the semantic tokens mainly aggregate local semantic information and are discretely and uniformly distributed in space.
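
The core operation of the STGM is thus cross-attention in which the semantic tokens act as the queries and the image tokens supply the keys and values. A minimal sketch of such a layer (illustrative names and standard PyTorch multi-head attention, not the authors' block):

import torch.nn as nn

class SemanticTokenCrossAttention(nn.Module):
    """Minimal sketch of one STGM attention layer: a few semantic tokens (queries)
    aggregate many image tokens (keys/values) into cluster-center-like tokens."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, semantic_tokens, image_tokens):
        # semantic_tokens: (B, N_s, C); image_tokens: (B, N, C) with N_s << N
        q = self.norm_q(semantic_tokens)
        kv = self.norm_kv(image_tokens)
        out, _ = self.attn(q, kv, kv)    # semantic tokens attend to the image tokens
        return semantic_tokens + out     # residual update of the cluster centers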


The intra- and inter-window spatial pooling procedure is straightforward: each initial cluster center is obtained as $P_i = \mathrm{softmax}(M_i + O_i) \cdot X^i_w$. It involves intra-window spatial attention logits $M_i$ and cross-window spatial attention logits $O_i$, each produced by an ultra-lightweight feature-transformation branch, where the computation of the latter depends on the result of the former. The overall process is as follows (a code sketch follows the list):

  1. Given the original feature map $X \in \mathbb{R}^{H \times W \times C}$, it is divided into $N_s = w_s \times w_s$ feature windows of spatial size $\frac{H}{w_s} \times \frac{W}{w_s}$. Each window eventually produces one initial cluster center.
  2. Compute the intra-window information-aggregation attention logits $M_i \in \mathbb{R}^{\frac{H}{w_s} \times \frac{W}{w_s}}$ as $M_i = \mathrm{Conv}(\mathrm{GeLU}(\mathrm{LN}(\mathrm{DepthConv}(X^i_w))))$. Here the features $X^i_w$ of the $i$-th window are transformed, and the number of channels is compressed to 1.
  3. Compute the inter-window information-aggregation attention logits $O_i$:
    1. First use the obtained $M_i$ to compute the per-window integration token $\hat{P}_i = \mathrm{softmax}(M_i) \cdot X^i_w$, and arrange all $\hat{P}_i$ according to their spatial layout into a 2D tensor $\hat{P} \in \mathbb{R}^{w_s \times w_s \times C}$.
    2. Then compute $O_i = \mathrm{Conv}(\mathrm{GeLU}(\mathrm{LN}(\mathrm{DepthConv}(\hat{P}))))$, which changes the number of channels from $C$ to $\frac{HW}{w_s^2}$; after reshaping, $O_i \in \mathbb{R}^{\frac{H}{w_s} \times \frac{W}{w_s}}$ is obtained.
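
A minimal sketch of this two-step initialization is given below. The shapes follow the description above, but the concrete layer choices (3x3 depthwise convolution, GroupNorm as a stand-in for LN) and feeding $\hat{P}$ into the second branch are my reading of the formulas, not the released code.

import torch
import torch.nn as nn

class ClusterCenterInit(nn.Module):
    """Sketch of the intra-/inter-window spatial pooling that initializes the
    semantic tokens. x of shape (B, H, W, C) is split into w_s x w_s windows of
    size (H/w_s) x (W/w_s); each window yields one initial cluster center."""
    def __init__(self, dim, H, W, w_s):
        super().__init__()
        self.w_s, self.h, self.w = w_s, H // w_s, W // w_s
        # intra-window branch: Conv(GeLU(LN(DepthConv(.)))), channels C -> 1, giving M_i
        self.intra = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1, groups=dim),   # DepthConv
            nn.GroupNorm(1, dim),                            # stand-in for LN
            nn.GELU(),
            nn.Conv2d(dim, 1, 1))
        # inter-window branch on P_hat: channels C -> HW / w_s^2, giving O_i
        self.inter = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1, groups=dim),
            nn.GroupNorm(1, dim),
            nn.GELU(),
            nn.Conv2d(dim, self.h * self.w, 1))

    def forward(self, x):                                    # x: (B, H, W, C)
        B, H, W, C = x.shape
        xw = x.view(B, self.w_s, self.h, self.w_s, self.w, C)
        xw = xw.permute(0, 1, 3, 2, 4, 5).reshape(B * self.w_s ** 2, self.h, self.w, C)
        M = self.intra(xw.permute(0, 3, 1, 2)).flatten(1)    # (B*N_s, h*w) intra-window logits
        tokens = xw.reshape(B * self.w_s ** 2, -1, C)        # image tokens of each window
        P_hat = torch.einsum('bn,bnc->bc', M.softmax(-1), tokens)  # per-window integration token
        P_hat = P_hat.view(B, self.w_s, self.w_s, C).permute(0, 3, 1, 2)
        O = self.inter(P_hat).permute(0, 2, 3, 1).reshape(B * self.w_s ** 2, -1)  # cross-window logits
        P = torch.einsum('bn,bnc->bc', (M + O).softmax(-1), tokens)  # P_i = softmax(M_i + O_i) . X_w^i
        return P.view(B, self.w_s ** 2, C)                   # (B, N_s, C) initial semantic tokens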

In the second attention layer, besides further clustering, the semantic tokens also incorporate global cluster centers $G \in \mathbb{R}^{N_s \times C}$ initialized with Gaussian noise, so that the network can adaptively let some semantic tokens focus on global semantic information. The authors emphasize the difference between this design and a learnable positional encoding: the global cluster centers are added directly to Q but not to K, which is different from positional encoding. Their experiments also show that this design works better than an actual positional encoding.
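
To make the "added to Q only" point concrete, here is a minimal sketch (illustrative names; in the released code this corresponds to semantic_token + self.semantic_token2 in the dumbbell unit shown later, where K/V additionally concatenate the semantic and image tokens):

import torch
import torch.nn as nn

class SecondSTGMLayer(nn.Module):
    """Sketch of the second STGM attention layer: learnable global cluster centers G
    are added to the queries only, not to the keys/values."""
    def __init__(self, dim, num_semantic_tokens, num_heads=8):
        super().__init__()
        self.global_centers = nn.Parameter(torch.zeros(1, num_semantic_tokens, dim))
        nn.init.trunc_normal_(self.global_centers, std=0.02)  # noise-style initialization
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, semantic_tokens, image_tokens):
        q = semantic_tokens + self.global_centers               # G enters the queries only
        kv = torch.cat([semantic_tokens, image_tokens], dim=1)  # keys/values are left unchanged
        out, _ = self.attn(q, kv, kv)
        return semantic_tokens + out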


The original image tokens are discarded after the STGM, and only the semantic tokens are kept for the subsequent transformer layers. Since the generation of semantic tokens is flexible and spatially aware, the proposed method can be plugged into both global and local ViTs. For a local ViT, several semantic tokens are generated independently within each window.

Application in local ViT

Different from the global ViT, the local ViT introduces the concept of local windows, and the sub-window partition that the proposed STGM uses to initialize the cluster centers is restricted to each local window. Therefore, for a local ViT with $N_w = w \times w$ windows, the actual number of semantic tokens is $w_s \times w_s \times N_w$.
Although the initial cluster centers come from the $w \times w$ windows, a larger window of size $w_k \times w_k$ is used to obtain K and V, and the design of the global cluster centers also alleviates the impact of the limited window size.
In local ViT models, each local transformer layer is usually followed by a cross-window connection layer, such as the shifted-window transformer layer that follows each local transformer layer in Swin Transformer. In this paper, attention in the local self-attention layers is computed within $w_s \times w_s$ windows, while the cross-window connection layer can compute self-attention in a larger sliding window (e.g. $4 \times w_s$), because the number of tokens in each window is small. For low-resolution inputs, this cross-window connection layer is equivalent to a global self-attention layer.
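
As a concrete example: in the released Swin-T code shown below, window_size = 7 and sample_window_size = 3 (i.e. $w_s = 3$), and the hard-coded unfold parameters suggest a $14 \times 14$ stage-3 feature map. Under that assumption there are $N_w = 2 \times 2 = 4$ local windows, each contributing $3 \times 3 = 9$ semantic tokens, so the subsequent layers operate on $36$ semantic tokens instead of $14 \times 14 = 196$ image tokens.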

Adapting to the transfer requirements of downstream tasks

It is worth noting that the token sparsification strategy loses a lot of spatial information, and none of the previous related schemes discuss how to use them in downstream tasks, which seriously hinders their application. Based on STViT, this paper proposes STViT-R, whose restoration module and dumbbell units periodically restore the full-resolution feature map, while the intermediate transformer layers keep operating on semantic tokens to save computation. This design improves the transferability of the proposed scheme to downstream tasks.

  • Restoration module: the original image tokens X are used as Q and the previous semantic tokens S as K and V; passing them through a transformer layer yields the updated image tokens (see the sketch after this list).
  • Dumbbell unit: in STViT-R, the transformer layers are regrouped into multiple dumbbell units, each consisting of four parts: the transformer layers in the first part process the image tokens; the second part is the semantic token generation module; the transformer layers in the third part process the semantic tokens; and the last part is the restoration module. By repeating multiple dumbbell units, the network retains detailed spatial information, which not only improves classification but also serves downstream tasks.
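
A minimal sketch of the restoration module along these lines (a standard pre-norm cross-attention block with illustrative names; in the released code it corresponds to the blk(x, semantic_token) calls shown later):

import torch.nn as nn

class RestorationModule(nn.Module):
    """Sketch of STViT-R's restoration step: the saved image tokens (queries) attend
    to the semantic tokens (keys/values) to recover a full-resolution feature map."""
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)), nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim))

    def forward(self, image_tokens, semantic_tokens):
        # image_tokens: (B, N, C) kept from before sparsification; semantic_tokens: (B, N_s, C)
        q, kv = self.norm_q(image_tokens), self.norm_kv(semantic_tokens)
        x = image_tokens + self.attn(q, kv, kv)[0]
        return x + self.mlp(self.norm2(x))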

Experimental results


In Table 7, the variant that also adds the global initial cluster centers to K makes them act as a true learnable positional encoding.

  • The experiments in Table 8 keep the total number of layers constant and only adjust how the layers are distributed among the dumbbell units.
  • The experiments in Table 9 constrain the FLOPs of all models to match the full model.


The authors also try applying positional encoding to the semantic tokens. The table above compares different positional encoding methods, including learnable positional encoding, conditional positional encoding, and relative positional encoding. Apart from relative positional encoding, which improves Swin-T by 1.2%, the positional encoding methods are not suitable for DeiT-S and Swin-T. These experiments show that the interaction between the proposed semantic tokens relies on high-level semantic information and hardly uses positional relations.


Three additional spatial pooling strategies are tested here, each producing a fixed set of 25 semantic tokens, but the proposed scheme performs best:

  • Large kernels and overlapping spatial pooling
  • Multi-scale spatial pooling
  • Adaptive Spatial Pooling

Core code

The code is available at https://github.com/changsn/STViT-R/blob/main/models/swin_transformer.py . It is quite messy and needs to be read carefully. In addition, the issues in the repository indicate that the released code is incomplete, so it can only serve as a prototype for understanding the method.

Four-Stage Dumbbell Unit

Corresponding to the four-stage structure mentioned above:

for i, blk in enumerate(self.blocks):
    # Dumbbell unit 1
    if i == 0:
        x = blk(x)
    elif i == 1:
        semantic_token = blk(x)
    elif i == 2:
        if self.use_global:
            semantic_token = blk(semantic_token+self.semantic_token2, torch.cat([semantic_token, x], dim=1))
        else:
            semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 2 and i < 5:
        semantic_token = blk(semantic_token)
    elif i == 5:
        x = blk(x, semantic_token)

    # Dumbbell unit 2
    elif i == 6:
        x = blk(x)
    elif i == 7:
        semantic_token = blk(x)
    elif i == 8:
        semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 8 and i < 11:
        semantic_token = blk(semantic_token)
    elif i == 11:
        x = blk(x, semantic_token)

    # Dumbbell unit 3
    elif i == 12:
        x = blk(x)
    elif i == 13:
        semantic_token = blk(x)
    elif i == 14:
        semantic_token = blk(semantic_token, torch.cat([semantic_token, x], dim=1))
    elif i > 14 and i < 17:
        semantic_token = blk(semantic_token)
    else:
        x = blk(x, semantic_token)

Looking at a single dumbbell unit in isolation:

if i == 0: # SwinTransformerBlock with 0 shift
    x = blk(x)
elif i == 1: # SemanticAttentionBlock: the first transformer block of the STGM
    semantic_token = blk(x)
elif i == 2: # Block (global attention): the second transformer block of the STGM
    if self.use_global:
        semantic_token = blk(semantic_token + self.semantic_token2,
                             torch.cat([semantic_token, x], dim=1))
    else:
        semantic_token = blk(semantic_token,
                             torch.cat([semantic_token, x], dim=1))
elif i > 2 and i < 5: # Local Attention(3) -> Global Attention(4)
    semantic_token = blk(semantic_token)
elif i == 5: # Global Attention
    x = blk(x, semantic_token)

Here, the global semantic token semantic_token2 that is added into the information flow is defined as:

self.use_global = use_global
if self.use_global:
    self.semantic_token2 = nn.Parameter(torch.zeros(1, self.num_samples, embed_dim))
    trunc_normal_(self.semantic_token2, std=.02)

As can be seen, a single dumbbell unit actually performs the following process:

  1. X=Local-MHSA(X)
  2. S=SemanticAttentionBlock(X)
  3. S=Global-MHSA(S, [S,X])
  4. S=Local-MHSA(S)
  5. S=Global-MHSA(S)
  6. X=Global-MHSA(X, S)

SemanticAttentionBlock

This corresponds to the first transformer block of the STGM described above, but the details differ from the paper: here K and V are built from both the original image tokens and the initial semantic tokens constructed by pooling.

class SemanticAttentionBlock(nn.Module):

    def __init__(self, dim, num_heads, multi_scale, window_size=7, sample_window_size=3, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale_init_value=1e-5,
                 use_conv_pos=False, shortcut=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.multi_scale = multi_scale(sample_window_size)
        self.attn = Attention(dim, num_heads=num_heads, window_size=None, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.use_conv_pos = use_conv_pos
        if self.use_conv_pos:
            self.conv_pos = PosCNN(dim, dim)
        self.shortcut = shortcut
        self.window_size = window_size
        self.sample_window_size = sample_window_size

    def forward(self, x, y=None):
        B, L, C = x.shape
        H = W = int(math.sqrt(L))
        x = x.view(B, H, W, C)
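        # If y is None, build the initial semantic tokens by pooling each window_size x window_size window (multi_scale); otherwise the given tokens y are reused as the initialization.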
        if y == None:
            xx = x.reshape(B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C)
            windows = xx.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, self.window_size, self.window_size, C).permute(0, 3, 1, 2)
            shortcut = self.multi_scale(windows)  # B*nW, W*W, C
            if self.use_conv_pos:
                shortcut = self.conv_pos(shortcut)
            pool_x = self.norm1(shortcut.reshape(B, -1, C)).reshape(-1, self.multi_scale.num_samples, C)
        else:
            B, L_, C = y.shape
            H_ = W_ = int(math.sqrt(L_))
            y = y.reshape(B, H_ // self.sample_window_size, self.sample_window_size, W_ // self.sample_window_size, self.sample_window_size, C)
            y = y.permute(0, 1, 3, 2, 4, 5).reshape(-1, self.sample_window_size*self.sample_window_size, C)
            shortcut = y
            if self.use_conv_pos:
                shortcut = self.conv_pos(shortcut)
            pool_x = self.norm1(shortcut.reshape(B, -1, C)).reshape(-1, self.multi_scale.num_samples, C)

        # produce K, V
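        # Overlapping 10x10 patches with stride 4 give K/V a larger receptive field than the 7x7 query windows; the hard-coded 10 and 100 appear to assume a 14x14 stage-3 feature map.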
        k_windows = F.unfold(x.permute(0, 3, 1, 2), kernel_size=10, stride=4).view(B, C, 10, 10, -1).permute(0, 4, 2, 3, 1)
        k_windows = k_windows.reshape(-1, 100, C)
        k_windows = torch.cat([shortcut, k_windows], dim=1)
        k_windows = self.norm1(k_windows.reshape(B, -1, C)).reshape(-1, 100+self.multi_scale.num_samples, C)

        if self.shortcut:
            x = shortcut + self.drop_path(self.layer_scale_1 * self.attn(pool_x, k_windows))
        else:
            x = self.layer_scale_1 * self.attn(pool_x, k_windows)
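        # Rearrange the per-window semantic tokens back into their spatial order before the MLP.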
        x = x.view(B, H // self.window_size, W // self.window_size, self.sample_window_size, self.sample_window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, -1, C)
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x
