2018 CVPR 《Squeeze-and-Excitation Networks》 PyTorch实现

import numpy as np
import torch
from torch import nn
from torch.nn import init


# SE-attention
# 方法出处 2018 CVPR 《Squeeze-and-Excitation Networks》
# 该方法用于捕获特征图之间的关系

class SEAttention(nn.Module):
    # 模型层的初始化
    def __init__(self, channel=512, reduction=16):
        # 所有继承于nn.Module的模型都要写这句话
        super(SEAttention, self).__init__()
        # 这个AdaptiveAvgPool2d会将输入特征图的宽和高
        # 自动的池化到我们在AdaptiveAvgPool2d参数中指定的大小
        # 比如
        # m = nn.AdaptiveAvgPool2d((5, 7))
        # input = torch.randn(1, 64, 8, 9)
        # output = m(input)
        # print(output.size())
        # 最后会输出[1,64,5,7]
        # 如果指定的最后宽和高相等可以只写一个
        # 比如
        # nn.AdaptiveAvgPool2d((1,1))==nn.AdaptiveAvgPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 定义线性变换层
        # 计算特征图不同通道之间的相关性
        self.fc = nn.Sequential(
            # 输入的是通道的维度
            # 输出的是通道维度缩减之后的维度
            # 其中缩减系数reduction由我们自己指定
            nn.Linear(channel, channel // reduction, bias=False),
            # 激活
            # 这个inplace=True表示在前一层输出的结果之上直接进行计算
            # 而不再重新开辟内存空间
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
        self.init_weights()

    # 每层权重的初始化
    # 权重初始化可以自己定义
    # 但其实PyTorch会给我们搭建的模型自动进行初始化
    def init_weights(self):
        # 遍历当前模型的每一层
        for m in self.modules():
            # 如果是卷积层
            if isinstance(m, nn.Conv2d):
                # kaiming初始化
                init.kaiming_normal_(m.weight, mode='fan_out')
                # 偏置为0
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            # 如果是正则化层
            elif isinstance(m, nn.BatchNorm2d):
                # 权重为1
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            # 如果是线性层
            elif isinstance(m, nn.Linear):
                # 以标准差为0.001的正态分布随机初始化
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    # 前向传递建立计算图
    def forward(self, x):
        # 获取输入特征图的批大小,通道数
        b, c, _, _ = x.size()
        # 可适应池化
        # 这个操作之后
        # 相当于每一个h*w的特征图变成了一个数字
        # 先将h*w的特征图按列相加
        # 然后再按行相加
        y = self.avg_pool(x).view(b, c)
        # 计算不同通道之间的相关权重
        # 将权重转换为[b,c,1,1]维度
        y = self.fc(y).view(b, c, 1, 1)
        # 给每一个通道的特征加权
        # expand_as将y的维度变为和x相同的维度
        # 相当于Numpy的广播机制
        # 因为PyTorch底层依赖Numpy
        # PyTorch自带广播机制
        # 其实不写这个函数也行
        return x * y.expand_as(x)


if __name__ == '__main__':
    # 可以将input看作是某一个卷积层输出的特征图
    # 维度是[50,512,7,7]
    # 代表批大小是50
    # 通道数512
    # 宽,高7*7
    input = torch.randn(50, 512, 7, 7)
    # 计算各个通道特征之间的相关性
    se = SEAttention(channel=512, reduction=8)
    output = se(input)
    print(output.shape)

猜你喜欢

转载自blog.csdn.net/Talantfuck/article/details/124560375