2018 BMCV „BAM: Bottleneck Attention Module“ Pytorch-Implementierung

import torch
from torch import nn
from torch.nn import init


# 通道注意力+空间注意力的改进版
# 方法出处 2018 BMCV 《BAM: Bottleneck Attention Module》
# 展平层
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    # 将输入的x,假如它是[B,C,H,W]维度的特征图
    # 其中B代表批大小
    # C代表通道数
    # H,W代表高和宽
    # 展平层将特征图展平为[B,C*H*W]
    # 其中每一个是一个行向量
    # 方便输入到下一个全连接层中
    def forward(self, x):
        return x.view(x.shape[0], -1)


# 通道注意力
class ChannelAttention(nn.Module):
    # 网络层的初始化
    def __init__(self, channel, reduction=16, num_layers=3):
        super(ChannelAttention, self).__init__()
        # 自适应平均池化
        # 将特征图的维度,假设是[B,C,H,W]
        # 平均池化到[B,C,1,1]
        # 相当于将切片矩阵H,W
        # 先按行相加
        # 在按列相加
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # 通道注意力中的多个全连接层的通道参数
        # gate_channels是个列表
        # 其中存放如下的数字
        # [channel,channel//reduction,channel//reduction,...(这里由num_layers控制,就是你想有多少个中间层)
        # channel]最后没有改变输入的通道数
        # 因为最后要按照通道数乘以通道权重
        gate_channels = [channel]
        gate_channels += [channel // reduction] * num_layers
        gate_channels += [channel]

        # 搭建全连接层计算通道注意力
        # Sequential以序列化的形式存储网络层
        self.ca = nn.Sequential()
        # 首先加入一个展平层,方便输入到后面的全连接层中
        self.ca.add_module('flatten', Flatten())
        # 循环,依次加入全连接层组合
        # 这个全连接组合包括
        # nn.Linear(channel,channel//reduction)或者
        # nn.Linear(channel//reduction,channel//reduction)形式的隐藏层
        # 紧接着全连接层之后的正则化层nn.BatchNorm1d
        # 因为输出的是向量所以用1d的正则化层
        # 然后是激活层
        for i in range(len(gate_channels) - 2):
            self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
            self.ca.add_module('relu%d' % i, nn.ReLU())
        # 最后将通道数还原的全连接层
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    # 前向传递建立计算图
    def forward(self, x):
        # 首先进行池化
        res = self.avgpool(x)
        # 然后通过全连接层
        # 计算出不同通道之间的相似性信息
        res = self.ca(res)
        # 改变通道注意力结果到与输入的特征图统一维度
        # 方便后面的相乘运算
        res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
        return res


# 空间注意力
class SpatialAttention(nn.Module):
    def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):
        super(SpatialAttention, self).__init__()
        # 空间注意力中中间的卷积层
        self.sa = nn.Sequential()
        # 首先是1*1的卷积层
        # 1*1的卷积层不改变卷积层的输入的宽高
        # 只是改变输入的通道数
        self.sa.add_module('conv_reduce1',
                           nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=channel // reduction))
        self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))
        self.sa.add_module('relu_reduce1', nn.ReLU())

        # 然后是3个3*3的卷积层
        # 这里指定了使用空洞卷积
        # 普通的卷积,卷积核之间的元素是相邻的
        # 在空洞卷积中,卷积核之间的元素会间隔指定的距离
        # 这个距离由我们自己指定
        # 因为元素之间存在空隙
        # 所以叫做空洞卷积
        # 普通卷积输出宽高的计算公式为
        # 输出的高=(输入的高+2*padding-卷积核大小)/卷积步幅+1
        # 带入参数可知这些3*3普通的卷积核没有改变输入的宽高

        # 但是这里的卷积层指定了空洞卷积
        # 计算公式为
        # 输出的高=(输入的高+2*padding-空洞距离(卷积核大小-1)-1)/卷积步幅+1
        # 带入参数
        # 输出的高=输入的高-2
        # 3次之后宽,高就变成1*1
        for i in range(num_layers):
            self.sa.add_module('conv_%d' % i, nn.Conv2d(kernel_size=3, in_channels=channel // reduction,
                                                        out_channels=channel // reduction, padding=1, dilation=dia_val))
            self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))
            self.sa.add_module('relu_%d' % i, nn.ReLU())

        # 最后是1*1的卷积层
        # 输出通道是1
        # 最后空间注意力维度是[B,1,1,1]
        self.sa.add_module('last_conv', nn.Conv2d(channel // reduction, 1, kernel_size=1))‘’
    #前向传递,建立计算图
    def forward(self, x):
        #计算空间注意力
        res = self.sa(x)
        #同理要转换成和输入相同的维度
        #方便和通道注意力的结果相加
        res = res.expand_as(x)
        return res


# BAM整体模型
class BAMBlock(nn.Module):
    #初始化层
    def __init__(self, channel=512, reduction=16, dia_val=2):
        super().__init__()
        #计算通道注意力
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        #计算空间注意力
        self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)
        #激活层
        self.sigmoid = nn.Sigmoid()
    #初始化网络层权重
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    #前向传递
    def forward(self, x):
        b, c, _, _ = x.size()
        #通道注意力结果
        sa_out = self.sa(x)
        #空间注意力结果
        ca_out = self.ca(x)
        #激活
        weight = self.sigmoid(sa_out + ca_out)
        #这里有个残差连接x+weight*x
        out = (1 + weight) * x
        return out


if __name__ == '__main__':
    # 可以将input看作一个特征图
    input = torch.randn(50, 512, 7, 7)
    # 捕获不同特征图不同通道之间的关系
    bam = BAMBlock(channel=512, reduction=16, dia_val=2)
    output = bam(input)
    print(output.shape)

おすすめ

転載: blog.csdn.net/Talantfuck/article/details/124557211
BAM