【论文解读】FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization

1.介绍

 最先进的延迟-精度权衡

在两个广泛使用的平台——移动设备和桌面GPU上的延迟是最快的

1.1 之前存在的问题

由于内存访问成本的增加,跳过连接(skip connections)在延迟方面造成了很大的开销

1.1.1 跳过连接 为什么 会增加内存访问成本?

这主要与以下两个因素有关:

  1. 内存需求:跳过连接会引入额外的特征映射或张量,从而增加了模型内存的使用。每个跳过连接需要存储一定数量的特征映射,这可能导致更多的内存占用。虽然这种增加通常不会显著影响模型的总体内存占用,但在某些情况下可能会有一些额外的内存压力。

  2. 内存访问:在深度神经网络中,计算和存储特征映射需要大量的内存访问。跳过连接增加了需要读取和写入的特征映射数量,这可能会导致内存访问成本的增加。

1.2 FastViT 改进之处

  • 引入了完全可重新参数化的RepMixer来删除跳过连接。
  • 将所有密集的k×k卷积替换为它们的分解版本,即逐通道卷积和逐点卷积。使用了线性训练时间过参数化( linear train-time overparameterization)。这些额外的分支仅在训练期间引入,并在推理时重新参数化。
  • 在早期阶段使用大卷积核来替代自关注层

1.3 前置知识

  • 解耦 训练 和 推理
  • 参数重参数化
  • 使用大核卷积的好处(为什么可以使用大核卷积替代自关注层)

2. FastViT

 (a)解耦训练时间和推理时间架构的FastViT架构概述。阶段1、2和3具有相同的体系结构,并使用RepMixer进行令牌混合。在阶段4中,自关注层用于令牌混合。(b)卷积系统的结构。(c)卷积- ffn的体系结构(d) RepMixer块概述,它在推理时重新参数化跳过连接。

基本内容已经在上图片上了

【论文解读参考】解读模型压缩24:FastViT:快速卷积 Transformer 的混合视觉架构 - 知乎 (zhihu.com)

3.代码

3.1 不同型号 

FastViT 一共有6个型号。

3.2 RepCPE

 在fastvit_sa12,fastvit_sa24,fastvit_sa36,fastvit_ma36等模型中

pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7))]

在推理时将原先的残差分支去除,只看前向推理函数

#reparam_conv 和 self.pe其实是参数基本相同的卷积
 def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

 3.3 basic_blocks

 token_mixer有两种选择,RepMixerBlock和AttentionBlock,一般也只会在最后一个层结构中使用AttentionBlock

3.3.1 RepMixerBlock

 和其他的Mixer结构相似

3.3.1.1 RepMixer

(1)训练结构

 残差结构的实现是通过带分支结构的MobileOneBlock减去不带分支结构的MobileOneBlock实现的

x = x + self.mixer(x) - self.norm(x)
(2)重参数

 重参数的原理和RepVGG相同,不过多赘述了

3.3.1.2 ConvFFN

其实也相对简单,值得一提的是ConvFFN似乎不进行重参数化

 3.4 PatchEmbed

区别一般的,通过一个ReparamLargeKernelConv和一个MobileOneBlock进行映射

ReparamLargeKernelConv来自Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNshttps://openaccess.thecvf.com/content/CVPR2022/papers/Ding_Scaling_Up_Your_Kernels_to_31x31_Revisiting_Large_Kernel_Design_CVPR_2022_paper.pdf

class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer."""

    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
        inference_mode: bool = False,
    ) -> None:

        super().__init__()
        block = list()
        block.append(
            ReparamLargeKernelConv(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=stride,
                groups=in_channels,
                small_kernel=3,
                inference_mode=inference_mode,
            )
        )
        block.append(
            MobileOneBlock(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                inference_mode=inference_mode,
                use_se=False,
                num_conv_branches=1,
            )
        )
        self.proj = nn.Sequential(*block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x

猜你喜欢

转载自blog.csdn.net/weixin_50862344/article/details/132397931