MaxViT: Multi-Axis Vision Transformer

Paper: https://arxiv.org/abs/2204.01697

Code: https://github.com/google-research/maxvit

If you cannot access GitHub, the code can also be downloaded here for free (note: no weight files included): https://download.csdn.net/download/weixin_44911037/86823798

This is a Google paper published at ECCV 2022. In my opinion, what it essentially provides is a plug-and-play module that combines CNNs with Transformers.

As is well known, Transformers have achieved impressive results in vision, yet without extensive pretraining ViT performs poorly on image recognition. The reason is that the Transformer has high model capacity but injects little inductive bias, so it tends to overfit and generalize poorly on the test set; moreover, self-attention scales badly with image size (its cost is quadratic in the number of tokens, so the larger the attention window, the heavier the computation), which limits its adoption in state-of-the-art vision backbones.

This paper therefore proposes an efficient and scalable multi-axis attention module, Max-SA, composed of two parts: blocked local attention (window attention) and dilated global attention (grid attention). These design choices allow global-local spatial interaction at arbitrary input resolutions with only linear complexity, which many methods cannot offer, given that vanilla self-attention is quadratic in N. To apply this module to computer vision tasks, the paper also introduces a new architectural element and, accordingly, a simple hierarchical vision backbone called MaxViT, built by repeating the basic building block over multiple stages.

The results: on image classification, MaxViT achieves state-of-the-art performance under a variety of settings. Without extra data, MaxViT reaches 86.5% top-1 accuracy on ImageNet-1K; with ImageNet-21K pretraining, the model reaches 88.7% top-1 accuracy, as shown in the figure below. It also performs well on a number of downstream tasks.

The contributions claimed by the paper:

1. A generic, strong Transformer backbone, MaxViT, which can capture both local and global spatial interactions at every stage of the network.

2. A novel, standalone multi-axis attention module composed of blocked local attention and dilated global attention, enjoying a global receptive field at linear complexity.

3. Extensive ablation studies (in other words, the final design was not arrived at in one go) over many design choices, including the number of layers, the layout, the use of MBConv, and so on, which converge into the final modular design: the MaxViT block.

4. Extensive experiments showing that MaxViT achieves SOTA results under various data regimes across a range of tasks, including image classification, object detection, image aesthetic assessment, and image generation.

Next, let's look at the overall MaxViT classification architecture (when the model is applied to downstream tasks, some of these hyperparameters need to change). The network first downsamples the input through S0 (a plain convolutional stem), then keeps downsampling through the four stages S1-S4, each of which repeatedly stacks the MaxViT block, and finally produces the output through pooling and a fully connected layer. This overall layout is very similar to ResNet, so the most important component of the paper is the MaxViT block itself, which is analyzed below; a rough sketch of the stage layout follows.
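
To make the stage layout concrete, here is a minimal sketch of the classification pipeline (stem conv -> four stages of stacked MaxViT blocks -> pooling -> FC head). It reuses the MaxViTBlock class listed later in this post; the block counts per stage, the channel widths, passing a plain dict as the config, and letting pool_stride=2 handle the per-stage downsampling are all my assumptions for illustration, not the repo's exact wiring.

import tensorflow as tf

def build_maxvit_classifier(num_classes=1000,
                            stage_blocks=(2, 2, 5, 2),        # blocks per stage (illustrative)
                            stage_channels=(64, 128, 256, 512)):
  """Illustrative sketch only: stem conv (S0) + stages S1-S4 of MaxViT blocks + head."""
  inputs = tf.keras.Input(shape=(224, 224, 3))

  # S0: plain convolutional stem, downsamples 224 -> 112.
  x = tf.keras.layers.Conv2D(64, 3, strides=2, padding='same', name='stem_conv0')(inputs)
  x = tf.keras.layers.Conv2D(64, 3, strides=1, padding='same', name='stem_conv1')(x)

  # S1-S4: each stage repeats the MaxViT block; the first block of each stage
  # is assumed to downsample (pool_stride=2) and widen the channels.
  for s, (channels, depth) in enumerate(zip(stage_channels, stage_blocks)):
    for i in range(depth):
      x = MaxViTBlock({
          'hidden_size': channels,
          'head_size': 32,
          'window_size': 7,
          'grid_size': 7,
          'pool_stride': 2 if i == 0 else 1,
      }, name=f'stage{s + 1}_block{i}')(x, training=False)  # training fixed for the sketch

  # Global pooling + fully connected classification head.
  x = tf.keras.layers.GlobalAveragePooling2D(name='head_pool')(x)
  outputs = tf.keras.layers.Dense(num_classes, name='head_fc')(x)
  return tf.keras.Model(inputs, outputs, name='maxvit_sketch')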

Now let's look at the MaxViT block. It consists of three parts: MBConv, Block Attention, and Grid Attention. Compared with local convolution, global interaction is one of the main advantages of self-attention. However, applying attention directly over the entire spatial extent is computationally infeasible, because the attention operator has quadratic complexity: standard self-attention makes every pixel interact with every other pixel. To solve this, the paper proposes a multi-axis approach that, by simply decomposing the spatial axes, splits full-size attention into two sparse forms, one local and one global. How is that done? (A back-of-envelope cost comparison is sketched right below.)
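
To see why this decomposition is linear rather than quadratic, here is a rough count of attention-score entries (ignoring channels and heads); the numbers only illustrate the scaling and are not real FLOP counts.

def full_attention_pairs(h, w):
  # Vanilla self-attention: every token attends to every token -> O(N^2).
  n = h * w
  return n * n

def max_sa_pairs(h, w, p=7, g=7):
  # Block attention: N / P^2 windows, each with (P^2)^2 score entries -> N * P^2.
  # Grid attention:  N / G^2 groups,  each with (G^2)^2 score entries -> N * G^2.
  n = h * w
  return n * p * p + n * g * g       # linear in N for fixed P and G

for size in (56, 112, 224):
  print(size, full_attention_pairs(size, size), max_sa_pairs(size, size))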

Suppose the input feature map is C×H×W. Split the spatial dimensions into non-overlapping blocks of size P×P (to stay consistent with the figure below we take P = 4; the paper's code uses 7). This gives C × H/4 × W/4 × 4 × 4, i.e. (C, H/4 × W/4, 16): each window contains 16 pixels and there are HW/16 windows in total. Self-attention is then applied inside each window, which yields the local (block) attention.

How is the grid attention done? From the result above, the feature map has been partitioned as (C, H/4 × W/4, 16), so one could compute a global attention across the windows themselves: the left part of the figure is divided into four windows, so we could let those 4 windows attend to one another. The problem is that the number of windows is not fixed, it equals (H/4) × (W/4), so if the input H and W are very large, the number of windows also becomes very large. The paper therefore takes another route: the feature map C×H×W is partitioned by a fixed G×G grid (the paper also uses G = 7, but here we set it to 4 to match the figure), i.e. it is reshaped to (C, G×G, H/G × W/G) and attention is applied along the G×G axis. This way, no matter how H and W change, each group of tokens that attend to one another always has the fixed size G×G, which keeps the computation in check; in the right part of the figure, pixels of the same color attend to each other. This part is a bit tricky to follow, but in my opinion it is still easier to grasp than the shifted windows in Swin Transformer; a small shape-tracing sketch is given right after this paragraph. If you want more detail, watch the paper-sharing talk given by the author Zhengzhong Tu and hosted by 极市: 极市直播第99期|ECCV2022-涂正中:让谷歌的骨干网络MaxViT治好你的科研内耗_哔哩哔哩_bilibili
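
The following self-contained shape trace mirrors the window_partition and grid_partition methods in the code further down (here with P = G = 4 to match the explanation above; the paper's models use 7). It only reshapes and transposes a toy tensor, then prints the token groups the attention operator would see.

import tensorflow as tf

B, H, W, C = 2, 16, 16, 32          # toy feature map (channels-last, as in the code below)
P = G = 4                           # window size P and grid size G

x = tf.random.normal((B, H, W, C))

# Block (window) partition: each P x P window is one attention group,
# so there are (H/P) * (W/P) groups of P*P neighbouring tokens -> local attention.
blocks = tf.reshape(x, (B, H // P, P, W // P, P, C))
blocks = tf.transpose(blocks, (0, 1, 3, 2, 4, 5))      # (B, H/P, W/P, P, P, C)
blocks = tf.reshape(blocks, (-1, P * P, C))
print('block groups:', blocks.shape)                   # (2*4*4, 16, 32)

# Grid partition: a fixed G x G grid; tokens sharing the same in-window offset
# form one group, i.e. G*G tokens strided by H/G and W/G -> sparse global attention.
grids = tf.reshape(x, (B, G, H // G, G, W // G, C))
grids = tf.transpose(grids, (0, 2, 4, 1, 3, 5))        # (B, H/G, W/G, G, G, C)
grids = tf.reshape(grids, (-1, G * G, C))
print('grid groups:', grids.shape)                     # (2*4*4, 16, 32)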

Now let's look at the variants of this network architecture:

The code below is written in TensorFlow; it is an excerpt, so it relies on repo-internal helpers (Attention, FFN, MBConvBlock, TrailDense, ops, create_config_from_dict) in addition to the usual tensorflow and logging imports. If it is hard to follow, you can look for an open-source PyTorch implementation online instead.

class MaxViTBlock(tf.keras.layers.Layer):
  """MaxViT block = MBConv + Block-Attention + FFN + Grid-Attention + FFN."""

  def _retrieve_config(self, config):
    required_keys = ['hidden_size', 'head_size', 'window_size', 'grid_size']
    optional_keys = {
        'num_heads': None,
        'expansion_rate': 4,
        'activation': 'gelu',
        'pool_type': 'avg',
        'pool_stride': 1,
        'dropatt': None,
        'dropout': None,
        'rel_attn_type': '2d_multi_head',
        'scale_ratio': None,
        'survival_prob': None,
        'ln_epsilon': 1e-5,
        'ln_dtype': None,
        'kernel_initializer': tf.random_normal_initializer(stddev=0.02),
        'bias_initializer': tf.zeros_initializer,
    }
    config = create_config_from_dict(config, required_keys, optional_keys)
    return config

  def __init__(self, config, name='transformer'):
    super().__init__(name=name)

    self._config = self._retrieve_config(config)

  def build(self, input_shape):
    config = self._config

    input_size = input_shape.as_list()[-1]

    if input_size != config.hidden_size:
      self._shortcut_proj = TrailDense(
          config.hidden_size,
          kernel_initializer=config.kernel_initializer,
          bias_initializer=config.bias_initializer,
          name='shortcut_proj')
    else:
      self._shortcut_proj = None

    self._block_attn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=config.ln_epsilon,
        dtype=config.ln_dtype,
        name='attn_layer_norm')

    self._grid_attn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=config.ln_epsilon,
        dtype=config.ln_dtype,
        name='attn_layer_norm_1')

    self._block_attention = Attention(
        config.hidden_size,
        config.head_size,
        num_heads=config.num_heads,
        dropatt=config.dropatt,
        rel_attn_type=config.rel_attn_type,
        scale_ratio=config.scale_ratio,
        kernel_initializer=config.kernel_initializer,
        bias_initializer=config.bias_initializer,
        name='attention')

    self._grid_attention = Attention(
        config.hidden_size,
        config.head_size,
        num_heads=config.num_heads,
        dropatt=config.dropatt,
        rel_attn_type=config.rel_attn_type,
        scale_ratio=config.scale_ratio,
        kernel_initializer=config.kernel_initializer,
        bias_initializer=config.bias_initializer,
        name='attention_1')

    self._block_ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=config.ln_epsilon,
        dtype=config.ln_dtype,
        name='ffn_layer_norm')

    self._grid_ffn_layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=config.ln_epsilon,
        dtype=config.ln_dtype,
        name='ffn_layer_norm_1')

    self._block_ffn = FFN(
        config.hidden_size,
        dropout=config.dropout,
        expansion_rate=config.expansion_rate,
        activation=config.activation,
        kernel_initializer=config.kernel_initializer,
        bias_initializer=config.bias_initializer,
        name='ffn')

    self._grid_ffn = FFN(
        config.hidden_size,
        dropout=config.dropout,
        expansion_rate=config.expansion_rate,
        activation=config.activation,
        kernel_initializer=config.kernel_initializer,
        bias_initializer=config.bias_initializer,
        name='ffn_1')

    self._mbconv = MBConvBlock(config)

  def downsample(self, inputs, name):
    config = self._config
    output = inputs
    if config.pool_stride > 1:
      output = ops.maybe_reshape_to_2d(output)
      output = ops.pooling_2d(output,
                              config.pool_type,
                              config.pool_stride,
                              padding='same',
                              data_format='channels_last',
                              name=name)
    return output

  def window_partition(self, features):
    """Partition the input feature maps into non-overlapping windows.

    Args:
      features: [B, H, W, C] feature maps.

    Returns:
      Partitioned features: [B * nH * nW, wSize, wSize, C].

    Raises:
      ValueError: If the feature map sizes are not divisible by window sizes.
    """
    config = self._config
    _, h, w, c = features.shape
    window_size = config.window_size

    if h % window_size != 0 or w % window_size != 0:
      raise ValueError(f'Feature map sizes {(h, w)} '
                       f'not divisible by window size ({window_size}).')

    features = tf.reshape(features, (-1,
                                     h // window_size, window_size,
                                     w // window_size, window_size, c))
    features = tf.transpose(features, (0, 1, 3, 2, 4, 5))
    features = tf.reshape(features, (-1, window_size, window_size, c))
    return features

  def window_stitch_back(self, features, window_size, h, w):
    """Reverse window_partition."""
    features = tf.reshape(features, [
        -1, h // window_size, w // window_size, window_size, window_size,
        features.shape[-1]
    ])
    return tf.reshape(
        tf.transpose(features, (0, 1, 3, 2, 4, 5)),
        [-1, h, w, features.shape[-1]])

  def grid_partition(self, features):
    """Partition the input feature maps into non-overlapping windows.

    Args:
      features: [B, H, W, C] feature maps.

    Returns:
      Partitioned features: [B, nH, nW, wSize, wSize, c].

    Raises:
      ValueError: If the feature map sizes are not divisible by window sizes.
    """
    config = self._config
    _, h, w, c = features.shape
    grid_size = config.grid_size
    if h % grid_size != 0 or w % grid_size != 0:
      raise ValueError(f'Feature map sizes {(h, w)} '
                       f'not divisible by grid size ({grid_size}).')
    features = tf.reshape(features, (-1,
                                     grid_size, h // grid_size,
                                     grid_size, w // grid_size, c))
    features = tf.transpose(features, (0, 2, 4, 1, 3, 5))
    features = tf.reshape(features, (-1, grid_size, grid_size, c))
    return features

  def grid_stitch_back(self, features, grid_size, h, w):
    """Reverse window_partition."""
    features = tf.reshape(features, [
        -1, h // grid_size, w // grid_size, grid_size,
        grid_size, features.shape[-1]
    ])
    return tf.reshape(
        tf.transpose(features, (0, 3, 1, 4, 2, 5)),
        [-1, h, w, features.shape[-1]])

  def block_shortcut_branch(self, shortcut):
    return shortcut

  def grid_shortcut_branch(self, shortcut):
    return shortcut

  def mbconv_shortcut_branch(self, shortcut):
    shortcut = self.downsample(shortcut, 'shortcut_pool')
    if self._shortcut_proj:
      shortcut = self._shortcut_proj(shortcut)
    return shortcut

  def block_attn_branch(self, inputs, training, attn_mask):
    config = self._config
    output = self._block_attn_layer_norm(inputs)
    # If grid-attention were placed first, we wouldn't need to downsample here.
    # Apply local block-attention.
    _, h, w, _ = output.shape
    output = self.window_partition(output)
    output = ops.maybe_reshape_to_1d(output)
    output = self._block_attention(output, training, attn_mask=attn_mask)
    output = self.window_stitch_back(output, config.window_size, h, w)
    return output

  def grid_attn_branch(self, inputs, training, attn_mask):
    config = self._config
    output = self._grid_attn_layer_norm(inputs)
    # Apply global grid-attention.
    _, h, w, _ = output.shape
    output = self.grid_partition(output)
    output = ops.maybe_reshape_to_1d(output)
    output = self._grid_attention(output, training, attn_mask=attn_mask)
    output = self.grid_stitch_back(output, config.grid_size, h, w)
    return output

  def block_ffn_branch(self, inputs, training):
    output = self._block_ffn_layer_norm(inputs)
    output = self._block_ffn(output, training)
    return output

  def grid_ffn_branch(self, inputs, training):
    output = self._grid_ffn_layer_norm(inputs)
    output = self._grid_ffn(output, training)
    return output

  def mbconv_branch(self, inputs, training):
    output = self._mbconv(inputs, training=training)
    return output

  def call(self, inputs, training, attn_mask=None):
    logging.debug('Block %s input shape: %s (%s)', self.name, inputs.shape,
                  inputs.dtype)

    config = self._config

    # MBConv
    output = self.mbconv_branch(inputs, training)

    # block self-attention
    shortcut = output
    output = self.block_attn_branch(output, training, attn_mask)
    if config.dropout:
      output = tf.keras.layers.Dropout(
          config.dropout, name='after_block_attn_drop')(
              output, training=training)
    output = ops.residual_add(output, shortcut, config.survival_prob, training)

    shortcut = output
    output = self.block_ffn_branch(output, training)
    if config.dropout:
      output = tf.keras.layers.Dropout(
          config.dropout, name='after_block_ffn_drop_1')(
              output, training=training)
    output = ops.residual_add(output, shortcut, config.survival_prob, training)

    # grid self-attention
    shortcut = output
    output = self.grid_attn_branch(output, training, attn_mask)
    if config.dropout:
      output = tf.keras.layers.Dropout(
          config.dropout, name='after_grid_attn_drop')(
              output, training=training)
    output = ops.residual_add(output, shortcut, config.survival_prob, training)

    shortcut = output
    output = self.grid_ffn_branch(output, training)
    if config.dropout:
      output = tf.keras.layers.Dropout(
          config.dropout, name='after_grid_ffn_drop')(
              output, training=training)
    output = ops.residual_add(output, shortcut, config.survival_prob, training)

    return output
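
Finally, a hypothetical usage sketch for a single block. It assumes the repo's helper modules (Attention, FFN, MBConvBlock, ops, create_config_from_dict) are importable and that a plain dict is accepted as the config; the real config handling may differ, so check the repository before running it.

import tensorflow as tf

# Hypothetical single-block example: 56 is divisible by both
# window_size and grid_size (7), which the partition methods require.
block = MaxViTBlock({
    'hidden_size': 64,     # output channels C
    'head_size': 32,       # per-head dimension
    'window_size': 7,      # P for block (local) attention
    'grid_size': 7,        # G for grid (global) attention
})

x = tf.random.normal((1, 56, 56, 64))
y = block(x, training=False)
print(y.shape)             # expected (1, 56, 56, 64) when pool_stride == 1 (no downsampling)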
