[Image Task] Transformer Series (3)

This article introduces three works that adapt the Transformer to different image tasks: CAT-Net for few-shot medical image segmentation (arXiv 2023), GRL for efficient image restoration and related tasks (CVPR 2023), and CloFormer, which rethinks local perception in lightweight vision Transformers (arXiv 2023).

Few Shot Medical Image Segmentation with Cross Attention Transformer, arXiv2023

Interpretation: New work from the Hong Kong University of Science and Technology (2023) | A novel attention mechanism effectively improves few-shot semantic segmentation accuracy on medical images!

Paper: https://arxiv.org/abs/2303.13867

Code: not yet open source

Introduction

In deep-learning-based medical image segmentation, training a strong model for large-scale deployment usually requires a large amount of manually labeled data for supervised training, which is very costly. To address this challenge, few-shot learning techniques have the potential to learn new categories from a limited number of samples.

Most few-shot segmentation methods aim to learn a meta-learner that predicts the segmentation of a query image based on a support image and its corresponding segmentation labels. The core question is how to effectively transfer knowledge from the support image to the query image. Existing few-shot segmentation methods mainly focus on the following two aspects:

  1. How to learn a meta-learner

  2. How to better transfer knowledge from the support image to the query image

Although prototype-based methods have achieved good results, they usually ignore the interaction between support and query features during training.

Therefore, this paper proposes a new network called CAT-Net, built on a Cross-Attention Transformer, which better captures the correlation between support and query images, promotes the interaction between support and query features, and suppresses useless pixel information, thereby improving both feature expressiveness and segmentation performance. In addition, the paper proposes an iterative training framework that feeds previous segmentation results back into the cross-attention Transformer to progressively refine features and segmentation results.

CAT-Net network

As shown in the framework diagram, CAT-Net mainly consists of three parts:

  1. A mask incorporated feature extraction sub-network (MIFE) that extracts the initial query and support features and an initial query mask;
  2. A cross-masked attention Transformer module (CMAT), in which the query and support features enhance each other to improve query prediction accuracy;
  3. An iterative refinement framework that applies the CMAT module repeatedly to continuously improve segmentation performance; the entire framework is trained end-to-end.

Mask Incorporated Feature Extraction (MIFE)

The MIFE sub-network takes the query and support images, together with the support masks, as input and generates their respective features. A simple classifier is then used to predict the segmentation of the query image. As the figure shows (a sketch of these steps follows the list):

  1. First, a feature extractor (i.e., ResNet-50) maps the query and support image pair Iq and Is into the feature space, producing multi-level feature maps Fq for the query image and Fs for the support image.
  2. The support mask is used to pool Fs; the pooled feature is then expanded and concatenated with Fq and Fs.
  3. A prior mask is further concatenated with the query features, and the correlation between query and support features is enhanced through a pixel-level similarity map.
  4. A simple classifier processes the query features to obtain the query mask.
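
Below is a minimal sketch of these steps based on the description above (not the authors' code): masked average pooling of the support features with the support mask, expansion to the spatial size, concatenation with the query features and the prior mask, and a small classifier head. The tensor shapes and module names are assumptions.

import torch
import torch.nn as nn

class MIFESketch(nn.Module):
    def __init__(self, feat_dim=256, num_classes=2):
        super().__init__()
        # assumed lightweight classifier head on the fused query features
        self.classifier = nn.Sequential(
            nn.Conv2d(feat_dim * 2 + 1, feat_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_dim, num_classes, 1),
        )

    def forward(self, f_q, f_s, m_s, prior_mask):
        # f_q, f_s: (B, C, H, W) query / support features from a shared backbone (e.g. ResNet-50)
        # m_s: (B, 1, H, W) support mask; prior_mask: (B, 1, H, W) pixel-level similarity prior
        proto = (f_s * m_s).sum(dim=(2, 3)) / (m_s.sum(dim=(2, 3)) + 1e-6)  # masked average pooling -> (B, C)
        proto = proto[:, :, None, None].expand_as(f_q)                      # expand to the spatial size
        fused = torch.cat([f_q, proto, prior_mask], dim=1)                  # concatenate with query features
        return self.classifier(fused)                                       # coarse query mask logits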

Cross Masked Attention Transformer (CMAT)

The CMAT module consists of three main components: a self-attention module, a cross-masked attention module, and a prototype segmentation module. Specifically,

  • The self-attention module extracts global information within the query and support features;
  • The cross-masked attention module passes foreground information between the two branches while removing redundant background information (a minimal sketch follows this list);
  • The prototype segmentation module generates the final prediction for the query image.
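
The sketch below illustrates the cross-masked attention idea, assuming query tokens attend to support tokens while the support mask suppresses background positions. It is an illustration of the mechanism, not the paper's exact module; the residual connection and module sizes are assumptions.

import torch.nn as nn

class CrossMaskedAttentionSketch(nn.Module):
    def __init__(self, dim=256, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, f_q, f_s, m_s):
        # f_q: (B, Nq, C) query tokens; f_s: (B, Ns, C) support tokens
        # m_s: (B, Ns) support mask in {0, 1}; background tokens are excluded from attention
        key_padding_mask = m_s < 0.5  # True = ignore this support position
        out, _ = self.attn(query=f_q, key=f_s, value=f_s,
                           key_padding_mask=key_padding_mask)
        return f_q + out  # residual connection (assumed)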

Iterative Refinement framework

This module is designed to optimize the query and support features as well as the query segmentation mask, so the refined segmentation is obtained through iterative optimization. Here, CMA(·) denotes the self-attention and cross-masked attention modules, and Proto(·) denotes the prototype segmentation module; the enhanced features and refined segmentation results after the i-th iteration are obtained by applying the CMA and Proto modules repeatedly.
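
Written with this notation, one plausible form of the i-th update is the following (a reconstruction from the description above, so the exact arguments are an assumption; F_q/F_s denote query/support features, M_q the predicted query mask, and M_s the support mask):

    % assumed reconstruction of the iterative refinement step
    \begin{aligned}
    F_q^{(i)},\; F_s^{(i)} &= \mathrm{CMA}\!\left(F_q^{(i-1)},\, F_s^{(i-1)},\, M_q^{(i-1)},\, M_s\right) \\
    M_q^{(i)} &= \mathrm{Proto}\!\left(F_q^{(i)},\, F_s^{(i)}\right)
    \end{aligned}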

Experiment

Ablation experiment

Table 2 verifies the effectiveness of each component in the network: S→Q and Q→S represent a branch in CAT-Net for enhancing support or query features, while S↔Q represents applying cross-attention to S and Q.

Table 3 shows the impact of stacking CMAT modules for different numbers of iterations. Increasing the number of modules improves performance: with 5 modules the Dice coefficient increases by 2.26%, while 4 CMAT modules offer a good balance between efficiency and performance.

Efficient and Explicit Modelling of Image Hierarchies for Image Restoration, CVPR2023

Interpretation: CVPR'2023 Plug and Play Series! | A Lightweight and Efficient Self-Attention Mechanism Helps Image Restoration Network Win SOTA!

Paper: https://arxiv.org/abs/2303.00748

Code: https://github.com/ofsoundof/GRL-Image-Restoration.git

Introduction

Image restoration aims to recover high-quality images from low-quality ones, which are usually produced by degradation processes such as blurring, downsampling, noise, and JPEG compression. Because important content information is lost during degradation, image restoration is a challenging inverse problem. To restore high-quality images, the rich information still present in the degraded image should therefore be fully exploited.

Figure 1. Local features (edges, colors) and regional features (pink box) can be well modeled by convolutional neural networks (CNNs) and window self-attention. In contrast, global features (cyan rectangles) are difficult to model effectively and explicitly.

Natural images contain features at global, regional, and local scales, all of which deep neural networks can exploit for image restoration. Local features are typically edges and colors; since they span only a few pixels, they can be modeled and captured with small convolution kernels (e.g., 3 x 3). Regional features usually span tens of pixels; such window-scale features often cover small objects or parts of large objects (e.g., the pink box in Figure 1 above). Because of their larger extent, modeling them with large convolution kernels would be too expensive in parameters and computation, so a Transformer with a window attention mechanism is a better choice. Beyond local and regional features, some features have a global span (cyan rectangles in Figure 1), mainly reflected in symmetry and multi-scale pattern repetition (Figure 1a), texture similarity at the same scale (Figure 1b), and structural similarity and consistency of large objects (Figure 1c). To model and exploit features at this range, the network needs a global understanding of the image.

Local and regional features can be well modeled and captured, but modeling global features poses two main challenges:

  • First, existing convolutional and window attention based image restoration networks cannot explicitly capture long-distance dependencies by using a single computation module, so global image understanding is mainly achieved by gradually propagating features through repeated computation modules.

  • Second, as the resolution of images continues to increase, long-range dependency modeling faces the challenge of computational burden.

The above discussion leads to a series of research questions:

  • How to efficiently model global-scale features in high-dimensional images for image restoration?

  • How to explicitly model image hierarchy information (local, regional, global) by a single computational module for high-dimensional image restoration?

  • How can this joint modeling lead to uniform performance improvements across different image restoration tasks?

To this end, this paper focuses on the above three research questions and proposes solutions one by one:

  • First, the paper proposes an anchor-based stripe self-attention mechanism for global-scope dependency modeling;
  • Second, a new Transformer network, GRL, is proposed to explicitly model global, regional, and local dependencies in a single computational module;
  • Finally, the proposed GRL network achieves SOTA performance on seven categories of image restoration tasks: image super-resolution, denoising, JPEG compression artifact removal, demosaicing, real-image super-resolution, single-image motion deblurring, and defocus deblurring.

GRL network

Figure (a) above shows the proposed GRL network architecture, which is composed of multiple Transformer layers. Figure (b) shows the Transformer layer, the computational module consisting of three sub-modules that model global, regional, and local image structure: the anchor-based stripe self-attention mechanism (Anchored Stripe Attention) models global image structure, the window-based self-attention mechanism (Window Attention V2) models regional features, and two stacked 3 x 3 convolutions followed by channel attention (Channel Attention) efficiently model local features. Figure (c) shows the structure of the anchor-based stripe self-attention mechanism, which helps the network capture image structure beyond the regional scale (i.e., global structure).
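
The key code further below covers the two attention modules; as a complement, here is a minimal sketch of the local branch described above (two stacked 3 x 3 convolutions followed by channel attention). The layer names and the squeeze-and-excitation style reduction ratio are assumptions for illustration.

import torch.nn as nn

class LocalConvChannelAttnSketch(nn.Module):
    def __init__(self, dim, reduction=4):
        super().__init__()
        self.convs = nn.Sequential(                    # two stacked 3x3 convolutions
            nn.Conv2d(dim, dim, 3, padding=1), nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )
        self.ca = nn.Sequential(                       # channel attention (SE-style, assumed)
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, dim // reduction, 1), nn.GELU(),
            nn.Conv2d(dim // reduction, dim, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.convs(x)
        return y * self.ca(y)   # re-weight channels of the locally aggregated features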

Anchored Stripe Attention

Although the self-attention mechanism in the Transformer architecture can model long-distance dependencies and capture global feature information well, the large number of image tokens leads to a huge amount of computation. To reduce the computational complexity, self-attention can be performed within window regions, but such window-based self-attention is limited by the window size and can only capture contextual features within the window. This raises the question: how can features beyond the window region be modeled with low computational cost?

Figures (a) and (b) above show the same image at two different resolutions; the blue pixel in (a) and the red pixel in (b) are taken from the same location. Figure (c) shows the attention map between the blue pixel and the other pixels, and Figure (d) shows the attention map between the red pixel and the other pixels. The two attention maps are very similar, which is what the paper calls cross-scale similarity.

Based on this cross-scale similarity, the effect of self-attention on a large-resolution image can be approximated by self-attention on a lower-resolution version, which has far fewer tokens. This greatly reduces the amount of computation while still effectively modeling features beyond the window region (global features).

To further reduce computation, the authors observe another important property of natural images: their features usually appear in an anisotropic manner, as shown in the figure above, e.g., the single objects in (c), (d), and (h) and the symmetry in (e) and (g). Therefore, isotropic global attention is redundant for capturing such anisotropic image features. Based on this, the paper proposes attention within anisotropic stripes, with four modes: horizontal stripes, vertical stripes, shifted horizontal stripes, and shifted vertical stripes. Horizontal and vertical stripe attention can be used alternately in the Transformer network. This attention scheme reduces the complexity of global self-attention while maintaining global modeling ability.

Therefore, combining this with the concept of anchor points, anchored stripe self-attention is proposed: the introduced anchors enable efficient self-attention computation within vertical and horizontal stripes.
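
A minimal sketch of the anchored attention idea (for illustration; the full implementation is the AnchorStripeAttention class in the key code below): attention between N stripe tokens is routed through M anchor tokens with M much smaller than N, so the two attention maps cost O(N·M) instead of O(N^2).

import torch.nn.functional as F

def anchored_attention_sketch(q, k, v, anchor):
    # q, k, v: (B, heads, N, d) tokens inside one stripe; anchor: (B, heads, M, d), M << N
    d = q.shape[-1]
    a2w = F.softmax(anchor @ k.transpose(-2, -1) / d ** 0.5, dim=-1)  # (B, heads, M, N): anchors summarize the keys
    summary = a2w @ v                                                 # (B, heads, M, d): values aggregated at the anchors
    w2a = F.softmax(q @ anchor.transpose(-2, -1) / d ** 0.5, dim=-1)  # (B, heads, N, M): queries attend to the anchors
    return w2a @ summary                                              # (B, heads, N, d)

In the repo, the two stages additionally use cosine similarity and learned affine transforms (AffineTransform), which this sketch omits.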

Experiment

See the paper for more experimental results.

 

Key code

mixed_attn_block_efficient.py

# https://github.com/ofsoundof/GRL-Image-Restoration/blob/main/models/common/mixed_attn_block_efficient.py

# Imports required by this excerpt (one workable set); the helpers AffineTransform, QKVProjection,
# AnchorProjection, window_partition, window_reverse, and _get_stripe_info are defined elsewhere
# in the same repo file.
from abc import ABC
from math import prod

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(ABC, nn.Module):
    def __init__(self):
        super(Attention, self).__init__()

    def attn(self, q, k, v, attn_transform, table, index, mask, reshape=True):
        # q, k, v: # nW*B, H, wh*ww, dim
        # cosine attention map
        B_, _, H, head_dim = q.shape
        if self.euclidean_dist:
            # print("use euclidean distance")
            attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1)
        else:
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        attn = attn_transform(attn, table, index, mask)
        # attention
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = attn @ v  # B_, H, N1, head_dim
        if reshape:
            x = x.transpose(1, 2).reshape(B_, -1, H * head_dim)
        # B_, N, C
        return x


class WindowAttention(Attention):
    r"""Window attention. QKV is the input to the forward method.
    Args:
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(
        self,
        input_resolution,
        window_size,
        num_heads,
        window_shift=False,
        attn_drop=0.0,
        pretrained_window_size=[0, 0],
        args=None,
    ):

        super(WindowAttention, self).__init__()
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = window_size[0] // 2 if window_shift else 0
        self.euclidean_dist = args.euclidean_dist

        self.attn_transform = AffineTransform(num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, qkv, x_size, table, index, mask):
        """
        Args:
            qkv: input QKV features with shape of (B, L, 3C)
            x_size: use x_size to determine whether the relative positional bias table and index
            need to be regenerated.
        """
        H, W = x_size
        B, L, C = qkv.shape
        qkv = qkv.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            qkv = torch.roll(
                qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )

        # partition windows
        qkv = window_partition(qkv, self.window_size)  # nW*B, wh, ww, C
        qkv = qkv.view(-1, prod(self.window_size), C)  # nW*B, wh*ww, C

        B_, N, _ = qkv.shape
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # nW*B, H, wh*ww, dim

        # attention
        x = self.attn(q, k, v, self.attn_transform, table, index, mask)

        # merge windows
        x = x.view(-1, *self.window_size, C // 3)
        x = window_reverse(x, self.window_size, x_size)  # B, H, W, C/3

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.view(B, L, C // 3)

        return x

    def extra_repr(self) -> str:
        return (
            f"window_size={self.window_size}, shift_size={self.shift_size}, "
            f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
        )

    def flops(self, N):
        pass


class AnchorStripeAttention(Attention):
    r"""Stripe attention
    Args:
        stripe_size (tuple[int]): The height and width of the stripe.
        num_heads (int): Number of attention heads.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
    """

    def __init__(
        self,
        input_resolution,
        stripe_size,
        stripe_groups,
        stripe_shift,
        num_heads,
        attn_drop=0.0,
        pretrained_stripe_size=[0, 0],
        anchor_window_down_factor=1,
        args=None,
    ):

        super(AnchorStripeAttention, self).__init__()
        self.input_resolution = input_resolution
        self.stripe_size = stripe_size  # Wh, Ww
        self.stripe_groups = stripe_groups
        self.stripe_shift = stripe_shift
        self.num_heads = num_heads
        self.pretrained_stripe_size = pretrained_stripe_size
        self.anchor_window_down_factor = anchor_window_down_factor
        self.euclidean_dist = args.euclidean_dist

        self.attn_transform1 = AffineTransform(num_heads)
        self.attn_transform2 = AffineTransform(num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self, qkv, anchor, x_size, table, index_a2w, index_w2a, mask_a2w, mask_w2a
    ):
        """
        Args:
            qkv: input features with shape of (B, L, C)
            anchor:
            x_size: use stripe_size to determine whether the relative positional bias table and index
            need to be regenerated.
        """
        H, W = x_size
        B, L, C = qkv.shape
        qkv = qkv.view(B, H, W, C)

        stripe_size, shift_size = _get_stripe_info(
            self.stripe_size, self.stripe_groups, self.stripe_shift, x_size
        )
        anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size]
        anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size]
        # cyclic shift
        if self.stripe_shift:
            qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
            anchor = torch.roll(
                anchor,
                shifts=(-anchor_shift_size[0], -anchor_shift_size[1]),
                dims=(1, 2),
            )

        # partition windows
        qkv = window_partition(qkv, stripe_size)  # nW*B, wh, ww, C
        qkv = qkv.view(-1, prod(stripe_size), C)  # nW*B, wh*ww, C
        anchor = window_partition(anchor, anchor_stripe_size)
        anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3)

        B_, N1, _ = qkv.shape
        N2 = anchor.shape[1]
        qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)

        # attention
        x = self.attn(
            anchor, k, v, self.attn_transform1, table, index_a2w, mask_a2w, False
        )
        x = self.attn(q, anchor, x, self.attn_transform2, table, index_w2a, mask_w2a)

        # merge windows
        x = x.view(B_, *stripe_size, C // 3)
        x = window_reverse(x, stripe_size, x_size)  # B H' W' C

        # reverse the shift
        if self.stripe_shift:
            x = torch.roll(x, shifts=shift_size, dims=(1, 2))

        x = x.view(B, H * W, C // 3)
        return x

    def extra_repr(self) -> str:
        return (
            f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
            f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}"
        )

    def flops(self, N):
        pass


class MixedAttention(nn.Module):
    r"""Mixed window attention and stripe attention
    Args:
        dim (int): Number of input channels.
        stripe_size (tuple[int]): The height and width of the stripe.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads_w,
        num_heads_s,
        window_size,
        window_shift,
        stripe_size,
        stripe_groups,
        stripe_shift,
        qkv_bias=True,
        qkv_proj_type="linear",
        anchor_proj_type="separable_conv",
        anchor_one_stage=True,
        anchor_window_down_factor=1,
        attn_drop=0.0,
        proj_drop=0.0,
        pretrained_window_size=[0, 0],
        pretrained_stripe_size=[0, 0],
        args=None,
    ):

        super(MixedAttention, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.args = args
        # print(args)
        self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args)
        # anchor is only used for stripe attention
        self.anchor = AnchorProjection(
            dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args
        )

        self.window_attn = WindowAttention(
            input_resolution,
            window_size,
            num_heads_w,
            window_shift,
            attn_drop,
            pretrained_window_size,
            args,
        )
        self.stripe_attn = AnchorStripeAttention(
            input_resolution,
            stripe_size,
            stripe_groups,
            stripe_shift,
            num_heads_s,
            attn_drop,
            pretrained_stripe_size,
            anchor_window_down_factor,
            args,
        )
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, x_size, table_index_mask):
        """
        Args:
            x: input features with shape of (B, L, C)
            stripe_size: use stripe_size to determine whether the relative positional bias table and index
            need to be regenerated.
        """
        B, L, C = x.shape

        # qkv projection
        qkv = self.qkv(x, x_size)
        qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
        # anchor projection
        anchor = self.anchor(x, x_size)

        # attention
        x_window = self.window_attn(
            qkv_window, x_size, *self._get_table_index_mask(table_index_mask, True)
        )
        x_stripe = self.stripe_attn(
            qkv_stripe,
            anchor,
            x_size,
            *self._get_table_index_mask(table_index_mask, False),
        )
        x = torch.cat([x_window, x_stripe], dim=-1)

        # output projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _get_table_index_mask(self, table_index_mask, window_attn=True):
        if window_attn:
            return (
                table_index_mask["table_w"],
                table_index_mask["index_w"],
                table_index_mask["mask_w"],
            )
        else:
            return (
                table_index_mask["table_s"],
                table_index_mask["index_a2w"],
                table_index_mask["index_w2a"],
                table_index_mask["mask_a2w"],
                table_index_mask["mask_w2a"],
            )

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}"

    def flops(self, N):
        pass

Rethinking Local Perception in Lightweight Vision Transformer, arXiv2023

Interpretation: Plug and Play Series | Tsinghua University proposes the latest high-efficiency mobile network architecture CloFormer: the perfect fusion of attention mechanism and convolution!

Paper: https://arxiv.org/abs/2303.17803

Code: https://github.com/qhfan/CloFormer

Introduction

This article mainly introduces a lightweight Vision Transformer architecture, CloFormer, designed for image tasks on mobile devices. CloFormer introduces AttnConv, a module that combines the attention mechanism with convolution to capture high-frequency local information. Compared with traditional convolution, AttnConv uses both shared weights and context-aware weights, which better handles the relationships between different positions in an image. Experimental results show that CloFormer achieves superior performance on image classification, object detection, and semantic segmentation tasks.

Much existing work focuses on exploring lightweight vision Transformers. From a frequency-domain perspective, this paper argues that most existing lightweight models only focus on designing sparse attention to efficiently process low-frequency global information, while using relatively simple methods to process high-frequency local information. Specifically, most models, such as EdgeViT and MobileViT, simply use plain convolutions to extract local representations, relying only on globally shared convolution weights to process high-frequency local information. Other methods, such as LVT, first unfold the tokens into windows and then use attention within each window to obtain high-frequency information; these methods only use context-aware weights specific to each token for local perception.

Although the aforementioned lightweight models perform remarkably well on multiple datasets, none of the methods attempt to design more efficient methods that take advantage of shared and context-aware weights to handle high-frequency local information. Methods based on shared weights, such as traditional convolutional neural networks, are characterized by translation equivariance. Unlike them, methods based on context-aware weights, such as LVT and NAT, have weights that can vary with input content. Both types of weights have their own advantages in local perception.

Inspired by this, this paper designs a lightweight vision Transformer, CloFormer, which exploits context-aware local enhancement. In particular, CloFormer uses a two-branch design.

Local branch

In the local branch, this paper introduces a carefully designed AttnConv, a simple yet effective attention-style convolution operator. AttnConv effectively fuses shared weights and context-aware weights to aggregate high-frequency local information. Specifically, AttnConv first extracts local representations using depthwise convolution (DWconv), whose weights are shared. It then uses context-aware weights to enhance the local features. Unlike methods such as Non-Local that generate context-aware weights with attention, AttnConv uses a gating mechanism to generate them, which introduces stronger nonlinearity than commonly used attention mechanisms. In addition, AttnConv applies convolutions to Query and Key to aggregate local information, computes the Hadamard product of Q and K, and applies a series of linear and nonlinear transformations to the result, producing context-aware weights in the range [-1, 1]. It is worth noting that AttnConv retains the translation equivariance of convolution, since all of its operations are based on convolutions.
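
A minimal sketch of this gating idea follows (a simplification for illustration; the repo's actual implementation is EfficientAttention.high_fre_attntion in the key code further below). The kernel size, the SiLU activation standing in for the repo's MemoryEfficientSwish, and the omitted scaling factor are assumptions.

import torch
import torch.nn as nn

class AttnConvSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)                                  # linear projection to Q, K, V
        self.dwconv = nn.Conv2d(dim * 3, dim * 3, 3, padding=1, groups=dim * 3)   # shared-weight local aggregation
        self.gate = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.SiLU(), nn.Conv2d(dim, dim, 1))

    def forward(self, x):
        # x: (B, C, H, W)
        q, k, v = self.dwconv(self.to_qkv(x)).chunk(3, dim=1)
        weights = torch.tanh(self.gate(q * k))   # context-aware weights in [-1, 1] from the Hadamard product
        return weights * v                       # enhance the locally aggregated V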

Global branch

In the global branch, a traditional attention mechanism is used, but K and V are down-sampled to reduce computation, thereby capturing low-frequency global information. Finally, CloFormer fuses the outputs of the local and global branches in a simple way, so that the model captures high-frequency and low-frequency information simultaneously. Overall, the design of CloFormer exploits shared weights and context-aware weights at the same time, improving its local perception ability and achieving excellent performance in multiple vision tasks such as image classification, object detection, and semantic segmentation.

CloFormer Network

CloFormer consists of a convolutional backbone and four stages, and each stage is a stack of Clo blocks and ConvFFNs. Specifically, the input image first passes through the convolutional backbone to obtain token representations. This backbone consists of four convolutions with strides 2, 2, 1, and 1, respectively. The tokens then go through the four stages of Clo blocks and ConvFFNs to extract hierarchical features. Finally, global average pooling and a fully connected layer produce the prediction.
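
For illustration, a minimal sketch of such a convolutional backbone is given below; only the stride pattern 2, 2, 1, 1 comes from the description, while the kernel sizes, channel widths, and activations are assumptions.

import torch.nn as nn

def conv_backbone_sketch(in_ch=3, embed_dim=64):
    # four convolutions with strides 2, 2, 1, 1 that turn the image into token maps
    return nn.Sequential(
        nn.Conv2d(in_ch, embed_dim // 2, 3, stride=2, padding=1), nn.GELU(),
        nn.Conv2d(embed_dim // 2, embed_dim, 3, stride=2, padding=1), nn.GELU(),
        nn.Conv2d(embed_dim, embed_dim, 3, stride=1, padding=1), nn.GELU(),
        nn.Conv2d(embed_dim, embed_dim, 3, stride=1, padding=1),
    )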

ConvFFN

In order to incorporate local information into the FFN process, this paper replaces the commonly used FFN with ConvFFN. The main difference between ConvFFN and the common FFN is that ConvFFN uses a depthwise convolution (DWconv) after the GELU activation, which enables it to aggregate local information. Thanks to the DWconv, downsampling can be performed directly inside ConvFFN without introducing a PatchMerge module. CloFormer uses two kinds of ConvFFN: the first is the in-stage ConvFFN, which directly uses skip connections; the other is the ConvFFN connecting two stages, which is mainly used for the downsampling operation.

Clo block

The Clo block is a very critical component of CloFormer. Each Clo block consists of a local branch and a global branch. In the global branch, K and V are first down-sampled, and then standard attention is performed on Q, K, and V to extract low-frequency global information.

Although the global branch is able to obtain a global receptive field, it is insufficient in dealing with high-frequency local information. To this end, CloFormer introduces a local branch and uses AttnConv to process high-frequency local information. AttnConv can fuse shared weights and context-aware weights, which can better handle high-frequency local information. Thus, CloFormer combines global and local advantages.

AttnConv

AttnConv is a key module that enables the proposed model to achieve high performance. It incorporates some standard attention-style operations. Specifically, in AttnConv, a linear transformation is first applied to obtain Q, K, and V. After the linear transformation, local feature aggregation with shared weights is performed on V. Then, context-aware local enhancement is performed based on the processed V together with Q and K. There are three steps:

  1. Use DWConv to perform local feature aggregation on V;
  2. Use DWConv to aggregate local features of Q and K, and then combine Q and K to generate context-aware weights, which are multiplied by V to enhance local features;
  3. The features of the local branch and the global branch are concatenated and fused.

The context-aware weights in AttnConv allow the model to adapt to the input better than traditional convolutions. Compared with the local self-attention mechanism, the introduction of shared weights enables the model to better handle high-frequency information, thereby improving performance. In addition, the method of generating context-aware weights introduces stronger nonlinearity and also improves performance. It should be noted that all operations in AttnConv are based on convolution, maintaining the translational equivariance of convolution.

Experiment

Key code

AttnConv

# https://github.com/qhfan/CloFormer/blob/main/classification/models/blocks.py

# Imports assumed for this excerpt; in the original repo, MemoryEfficientSwish and DropPath
# may come from sibling modules (here they are taken from efficientnet_pytorch and timm).
from typing import List

import torch
import torch.nn as nn
from efficientnet_pytorch.utils import MemoryEfficientSwish
from timm.models.layers import DropPath


class AttnMap(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.act_block = nn.Sequential(
                            nn.Conv2d(dim, dim, 1, 1, 0),
                            MemoryEfficientSwish(),
                            nn.Conv2d(dim, dim, 1, 1, 0)
                            #nn.Identity()
                         )
    def forward(self, x):
        return self.act_block(x)

class EfficientAttention(nn.Module):

    def __init__(self, dim, num_heads, group_split: List[int], kernel_sizes: List[int], window_size=7, 
                 attn_drop=0., proj_drop=0., qkv_bias=True):
        super().__init__()
        assert sum(group_split) == num_heads
        assert len(kernel_sizes) + 1 == len(group_split)
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split
        convs = []
        act_blocks = []
        qkvs = []
        #projs = []
        for i in range(len(kernel_sizes)):
            kernel_size = kernel_sizes[i]
            group_head = group_split[i]
            if group_head == 0:
                continue
            convs.append(nn.Conv2d(3*self.dim_head*group_head, 3*self.dim_head*group_head, kernel_size,
                         1, kernel_size//2, groups=3*self.dim_head*group_head))
            act_blocks.append(AttnMap(self.dim_head*group_head))
            qkvs.append(nn.Conv2d(dim, 3*group_head*self.dim_head, 1, 1, 0, bias=qkv_bias))
            #projs.append(nn.Linear(group_head*self.dim_head, group_head*self.dim_head, bias=qkv_bias))
        if group_split[-1] != 0:
            self.global_q = nn.Conv2d(dim, group_split[-1]*self.dim_head, 1, 1, 0, bias=qkv_bias)
            self.global_kv = nn.Conv2d(dim, group_split[-1]*self.dim_head*2, 1, 1, 0, bias=qkv_bias)
            #self.global_proj = nn.Linear(group_split[-1]*self.dim_head, group_split[-1]*self.dim_head, bias=qkv_bias)
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size!=1 else nn.Identity()

        self.convs = nn.ModuleList(convs)
        self.act_blocks = nn.ModuleList(act_blocks)
        self.qkvs = nn.ModuleList(qkvs)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()
        qkv = to_qkv(x) #(b (3 m d) h w)
        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous() #(3 b (m d) h w)
        q, k, v = qkv #(b (m d) h w)
        attn = attn_block(q.mul(k)).mul(self.scalor)
        attn = self.attn_drop(torch.tanh(attn))
        res = attn.mul(v) #(b (m d) h w)
        return res
        
    def low_fre_attention(self, x : torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
        '''
        x: (b c h w)
        '''
        b, c, h, w = x.size()
        
        q = to_q(x).reshape(b, -1, self.dim_head, h*w).transpose(-1, -2).contiguous() #(b m (h w) d)
        kv = avgpool(x) #(b c h w)
        kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h*w)//(self.window_size**2)).permute(1, 0, 2, 4, 3).contiguous() #(2 b m (H W) d)
        k, v = kv #(b m (H W) d)
        attn = self.scalor * q @ k.transpose(-1, -2) #(b m (h w) (H W))
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v #(b m (h w) d)
        res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous()
        return res

    def forward(self, x: torch.Tensor):
        '''
        x: (b c h w)
        '''
        res = []
        for i in range(len(self.kernel_sizes)):
            if self.group_split[i] == 0:
                continue
            res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i]))
        if self.group_split[-1] != 0:
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))

class ConvFFN(nn.Module):

    def __init__(self, in_channels, hidden_channels, kernel_size, stride,
                 out_channels, act_layer=nn.GELU, drop_out=0.):
        super().__init__()
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, 1, 0)
        self.act = act_layer()
        self.dwconv = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, stride, 
                                kernel_size//2, groups=hidden_channels)
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0)
        self.drop = nn.Dropout(drop_out)

    def forward(self, x: torch.Tensor):
        '''
        x: (b h w c)
        '''
        x = self.fc1(x)
        x = self.act(x)
        x = self.dwconv(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class EfficientBlock(nn.Module):

    def __init__(self, dim, out_dim, num_heads, group_split: List[int], kernel_sizes: List[int], window_size: int,
                 mlp_kernel_size: int, mlp_ratio: int, stride: int, attn_drop=0., mlp_drop=0., qkv_bias=True,
                 drop_path=0.):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.norm1 = nn.GroupNorm(1, dim)
        self.attn = EfficientAttention(dim, num_heads, group_split, kernel_sizes, window_size,
                                       attn_drop, mlp_drop, qkv_bias)
        self.drop_path = DropPath(drop_path)
        self.norm2 = nn.GroupNorm(1, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.stride = stride
        if stride == 1:
            self.downsample = nn.Identity()
        else:
            self.downsample = nn.Sequential(
                                nn.Conv2d(dim, dim, mlp_kernel_size, 2, mlp_kernel_size//2),
                                nn.SyncBatchNorm(dim),
                                nn.Conv2d(dim, out_dim, 1, 1, 0),
                            )
        self.mlp = ConvFFN(dim, mlp_hidden_dim, mlp_kernel_size, stride, out_dim, 
                        drop_out=mlp_drop)
    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = self.downsample(x) + self.drop_path(self.mlp(self.norm2(x)))
        return x

if __name__ == '__main__':
    input = torch.randn(4, 96, 56, 56)
    model = EfficientBlock(96, 192, 3, [1, 1, 1], [7, 5], 7, 7, 4, 2)
    print(model(input).size())


Origin blog.csdn.net/m0_61899108/article/details/131156775