【神经网络】(22) ConvMixer 代码复现,网络解析,附TensorFlow完整代码

大家好,今天和各位分享一下如何使用 TensorFlow 构建 ConvMixer 卷积神经网络模型.

我偶然间找到了这个网络,这是一个实现起来非常简单的模型,但是能够实现较好的精度表现,超过了 Vision Transformer 模型,有种大道至简的感觉。

论文地址:https://openreview.net/forum?id=TVHS5Y4dNvM


1. 引言

近年来 Transformer 模型在 CV 领域中不断挑战卷积神经网络的统治地位,出现了能和 CNN 扳手腕的 VisionTransformer 以及划时代的 SwinTransformer。这篇文章作者主要针对的是 VIT 模型,他提出了一个问题:ViT的性能是由于其强大的Transformer结构产生的,还是由于使用patch作为输入表示产生的

在论文中,作者证明了PatchEmbedding对VIT的精度影响更大,并提出了一个非常简单的模型ConvMixer,在思想上类似于ViT和MLP-Mixer。模型直接将patch作为输入,分离空间和通道尺寸的混合建模并在整个网络中保持相同大小的分辨率

尽管ConvMixer的设计很简单,但是实验证明了ConvMixer在相似的参数计数和数据集大小方面优于ViT、MLP-Mixer及其一些变体,以及经典的视觉模型,如ResNet。


2. 模型构建

我们先导入需要用到的工具包

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

2.1 Patch Embedding

patchembedding 的主要功能是对原始输入图像(h, w)划分图像块。首先指定每个图像块的size为(patch_size, patch_size)将每张图像划分出(h//patch_size, w//patch_size)个图像块

它的实现方法就是通过一个 kernel_size 和 stride 都等于 patch_size 的卷积层来划分图像块

代码如下:

# ---------------------------------------------- #
#(1)patchembedding层
'''out_channel代表输出通道数, patch_size代表每个图像块的宽高'''
# ---------------------------------------------- #
def patchembed(inputs, out_channel, patch_size):
    
    # 卷积核大小为patch_size*patch_size,步长为patch_size的标准卷积划分图像块
    x = layers.Conv2D(filters = out_channel,   # 输出通道数
                      kernel_size = patch_size,  # 卷积核尺寸
                      strides = patch_size,  # 卷积步长
                      padding = 'same',  # 
                      use_bias = False)(inputs)

    # GELU激活函数、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    return x

2.2 特征提取层

这里的特征提取层由三部分组成,深度卷积(depthwise conv)、逐点卷积(pointwise conv)、残差连接(shortcut)。如下图ConvMixer Layer所示。

关于深度可分离卷积的原理,看我这篇博文:https://blog.csdn.net/dgvv4/article/details/123476899

首先输入特征图,经过深度卷积提取特征图长宽方向的信息,其中卷积核的个数和输入特征图的通道数相同,且输入和输出特征图的shape相同;接着残差连接输入和输出;然后经过1*1逐点卷积融合通道方向的信息,其中卷积核的个数和输出特征图的个数相同

代码如下:

# ---------------------------------------------- #
#(2)单个特征提取模块
'''out_channel代表逐点卷积的输出通道数, kernel_size代表深度卷积的卷积核大小'''
# ---------------------------------------------- #
def layer(inputs, out_channel, kernel_size):

    # 9*9深度卷积提取特征
    x = layers.DepthwiseConv2D(kernel_size = kernel_size,  # 卷积核大小
                               strides = 1,  # 不经过下采样
                               padding = 'same',  # 卷积前后size不变
                               use_bias = False)(inputs)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    # 残差连接
    x = x + inputs
    
    # 1*1逐点卷积
    x = layers.Conv2D(filters = out_channel,  # 输出通道数
                      kernel_size = 1,  # 1*1卷积
                      strides = 1)(x)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    return x

# ---------------------------------------------- #
#(3)堆叠多个特征提取模块
'''depth代表堆叠的次数'''
# ---------------------------------------------- #
def blocks(x, depth, out_channel, kernel_size):
    # 堆叠多个特征提取模块
    for _ in range(depth):
        x = layer(x, out_channel, kernel_size)
    
    return x

2.3 主干网络

ConvMixer的网络结构非常简单。首先图像经过 PatchEmbedding 划分图像块,然后经过12个特征提取模块,最后经过一个全连接层得到输出结果。

这里构建 ConvMixer-1536/20 网络模型,其中 1536 代表patchembedding 层的输出通道数20 代表堆叠20个特征提取模块每个图像块patch_size的大小为7*7特征提取模块中深度卷积的卷积核尺寸为 9*9

代码如下:

# ---------------------------------------------- #
#(4)主干网络
'''input_shape代表输入图像的尺寸(不包含batch维度), num_classes代表分类数'''
# ---------------------------------------------- #
def convmixer(input_shape, num_classes):

    # 构造输入层[b,224,224,3]
    inputs = keras.Input(shape=input_shape)
    # patchembedding层[b,224//7,224//7,1536]
    x = patchembed(inputs, out_channel=1536, patch_size=7)
    # 经过20个特征提取层[b,224//7,224//7,1536]
    x = blocks(x, depth=20, out_channel=1536, kernel_size=9)

    # 全局平均池化[b,1536]
    x = layers.GlobalAveragePooling2D()(x)
    # 全连接分类[b,num_classes]
    outputs = layers.Dense(num_classes)(x)

    # 构造网络
    model = keras.Model(inputs, outputs)

    return model

2.4 查看网络架构

以1000分类为例查看网络结构

# ---------------------------------------------- #
#(5)查看网络结构
# ---------------------------------------------- #
if __name__ == '__main__':
    # 接受模型
    model = convmixer(input_shape=[224,224,3],num_classes=1000)
    # 查看网络结构
    model.summary()

网络结构如下:

 conv2d_20 (Conv2D)             (None, 32, 32, 1536  2360832     ['tf.__operators__.add_19[0][0]']
                                )

 activation_40 (Activation)     (None, 32, 32, 1536  0           ['conv2d_20[0][0]']
                                )

 batch_normalization_40 (BatchN  (None, 32, 32, 1536  6144       ['activation_40[0][0]']
 ormalization)                  )

 global_average_pooling2d (Glob  (None, 1536)        0           ['batch_normalization_40[0][0]']
 alAveragePooling2D)

 dense (Dense)                  (None, 1000)         1537000     ['global_average_pooling2d[0][0]'
                                                                 ]
==================================================================================================
Total params: 51,719,656
Trainable params: 51,593,704
Non-trainable params: 125,952
__________________________________________________________________________________________________

猜你喜欢

转载自blog.csdn.net/dgvv4/article/details/125207966