Paper Portal:圧迫と励起のネットワーク
SENet の目的:
ネットワークがフィーチャ レイヤーのいくつかの重要なチャネルにさらに注意を払うようにします。
SENetの方式:
SEblock チャネル アテンションメカニズム モジュールを使用します。
SEblock の構造:
SEblock の使用:
SEblock はプラグ アンド プレイモジュールとして使用でき、理論的には任意の機能レイヤーに追加できます。オリジナルのテキストでは、Inception ModuleとResidual Moduleでその使用法が説明されています。
import torch
import torch.nn as nn
class SEblock(nn.Module): # Squeeze and Excitation block
def __init__(self, channels, ratio=16):
super(SEblock, self).__init__()
channels = channels # 输入的feature map通道数
hidden_channels = channels // ratio # 中间过程的通道数,原文reduction ratio设为16
self.attn = nn.Sequential(
nn.AdaptiveAvgPool2d(1), # avgpool
nn.Conv2d(channels, hidden_channels, 1, 1, 0), # 1x1conv,替代linear
nn.ReLU(), # relu
nn.Conv2d(hidden_channels, channels, 1, 1, 0), # 1x1conv,替代linear
nn.Sigmoid() # sigmoid,将输出压缩到(0,1)
)
def forward(self, x):
weights = self.attn(x) # feature map每个通道的重要性权重(0,1),对应原文的sc
return weights * x # 将计算得到的weights与输入的feature map相乘