SENet注意力机制——pytorch实现

论文传送门:Squeeze-and-Excitation Networks

SENet的目的:

让网络对特征层的某些重要的通道更加关注。

SENet的方法:

使用SEblock通道注意力机制模块。

SEblock的结构:

Alt

SEblock的使用:

SEblock可以作为一个即插即用的模块,理论上可以添加到任意特征层后。原文中给出其在Inception ModuleResidual Module中的使用方法。
Alt

import torch
import torch.nn as nn


class SEblock(nn.Module):  # Squeeze and Excitation block
    def __init__(self, channels, ratio=16):
        super(SEblock, self).__init__()
        channels = channels  # 输入的feature map通道数
        hidden_channels = channels // ratio  # 中间过程的通道数,原文reduction ratio设为16
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # avgpool
            nn.Conv2d(channels, hidden_channels, 1, 1, 0),  # 1x1conv,替代linear
            nn.ReLU(),  # relu
            nn.Conv2d(hidden_channels, channels, 1, 1, 0),  # 1x1conv,替代linear
            nn.Sigmoid()  # sigmoid,将输出压缩到(0,1)
        )

    def forward(self, x):
        weights = self.attn(x)  # feature map每个通道的重要性权重(0,1),对应原文的sc
        return weights * x  # 将计算得到的weights与输入的feature map相乘

猜你喜欢

转载自blog.csdn.net/Peach_____/article/details/128677412