tensorflow2.2_实现SENet

1. SENet介绍

    SENet 是 ImageNet Challenge 图像识别比赛 2017 年的冠军,是来自 Momenta 公司的团队完成。他们提出了 Squeeze-and-Excitation Networks(简称 SENet)。SENet一般不单独使用,通常都是与其它模型结合使用,使其效果更好。
    在一般的卷积层中通过卷积核会生成许多不同的特征图,但在这些特征图中并不是所有的特征图都是很重要的,也许有些特征可以忽略。如果我们可以将重要的特征加强,而不重要的特征可以减弱,这样我们的模型效果可能会更好。
    所以SENet就可以实现这样的效果,它的核心思想是:给特征图增加注意力和门控机制,增强重要的特征图的信息,减弱不重要的特征图的信息
在这里插入图片描述
其中:

  • W、H、C分别代表图片的宽、高、通道数。
  • Global Pooling代表全局池化。
  • FC代表全连接层
  • ReLu、Sigmoid分别代表激活函数分别使用ReLu和Sigmoid。
  • r代表缩减率,意思是在第一个全连接层缩减的通道数。

    如上图,左边是普通的残差结构,右边是加上了SENet的残差结构。
    加上SENet后,首先是做平均池化,得到特征图的压缩特征。第二层进行全连接层,我们也可以使用1x1的卷积核来代替,效果是一样的之后使用ReLu激活函数。第三层就是全连接层,也可以使用1x1卷积来代替,之后使用Sigmoid函数,使输出范围在0~1之间,起到门控的作用。Sigmoid输出的激活值最后会乘以初始残差结构最后一个卷积层的输出结果,对特征图的数值大小进行控制。如果是重要的特征图,会保持比较大的数值;如果是不重要的特征图,特征图的数值就会变小。
下图是论文中的图,和上面介绍的差不多。
在这里插入图片描述
下图是论文中三个网络的模型结构。
在这里插入图片描述
其都是在每个block后加上SENet。
下图是个模型的比较
在这里插入图片描述

2. SEnet实现代码

def SE_block(x_0, r = 16):
    channels = x_0.shape[-1]
    x = GlobalAvgPool2D()(x_0)
    # (?, ?) -> (?, 1, 1, ?)
    x = x[:, None, None, :]
    # 用2个1x1卷积代替全连接
    x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    x = Multiply()([x_0, x])
    
    return x

3. Resnet18与SENet结合

Resnet50也可以在残差块结构的最后加上SENet。

from os import name
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Dense, ZeroPadding2D, Conv2D,
                                     MaxPool2D, GlobalAvgPool2D, Input,
                                     BatchNormalization, Activation, Add, Multiply)
from tensorflow.keras.models import Model
from plot_model import plot_model
from tensorflow.python.keras.layers.pooling import AveragePooling2D

def SE_block(x_0, r = 16):
    channels = x_0.shape[-1]
    x = GlobalAvgPool2D()(x_0)
    # (?, ?) -> (?, 1, 1, ?)
    x = x[:, None, None, :]
    # 用2个1x1卷积代替全连接
    x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    x = Multiply()([x_0, x])
    
    return x

# 结构快
def block(x, filters, strides=2, r=16, conv_short=True):
    if conv_short:
        short_cut = Conv2D(filters=filters, kernel_size=1, strides=strides, padding='valid')(x)
        short_cut = BatchNormalization(epsilon=1.001e-5)(short_cut)
    else:
        short_cut = x

    
    # 2层卷积
    x = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Activation('relu')(x)

    x = SE_block(x, r=r)
    
    x = Add()([x, short_cut])
    x = Activation('relu')(x)

    return x

def SE_Resnet18(inputs, classes):
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu')(inputs)
    x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=64, strides=1, conv_short=False)
    x = block(x, filters=128, strides=2, conv_short=True)
    x = block(x, filters=128, strides=1, conv_short=False)
    x = block(x, filters=256, strides=2, conv_short=True)
    x = block(x, filters=256, strides=1, conv_short=False)
    x = block(x, filters=512, strides=2, conv_short=True)
    x = block(x, filters=512, strides=1, conv_short=False)
    
    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
    return x
    

if __name__ == '__main__':
    is_show_picture = False
    inputs = Input(shape=(224,224,3))
    classes = 17
    model = Model(inputs=inputs, outputs=SE_Resnet18(inputs, classes))
    model.summary()
    print(len(model.layers))
    for i in range(len(model.layers)):
        print(i, model.layers[i])
    if is_show_picture:
        plot_model(model,
           to_file='./nets_picture/SE_Resnet18.png',
           )
        print("plot_model------------------------>")
    
    
    

Guess you like

Origin blog.csdn.net/qq_42025868/article/details/122490891