SENet アテンション メカニズム - pytorch の実装

Paper Portal:圧迫と励起のネットワーク

SENet の目的:

ネットワークがフィーチャ レイヤーのいくつかの重要なチャネルにさらに注意を払うようにします。

SENetの方式:

SEblock チャネル アテンションメカニズム モジュールを使用します

SEblock の構造:

オルタナティブ

SEblock の使用:

SEblock はプラグ アンド プレイモジュールとして使用でき、理論的には任意の機能レイヤーに追加できます。オリジナルのテキストでは、Inception ModuleResidual Moduleでその使用法が説明されています。
オルタナティブ

import torch
import torch.nn as nn


class SEblock(nn.Module):  # Squeeze and Excitation block
    def __init__(self, channels, ratio=16):
        super(SEblock, self).__init__()
        channels = channels  # 输入的feature map通道数
        hidden_channels = channels // ratio  # 中间过程的通道数,原文reduction ratio设为16
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # avgpool
            nn.Conv2d(channels, hidden_channels, 1, 1, 0),  # 1x1conv,替代linear
            nn.ReLU(),  # relu
            nn.Conv2d(hidden_channels, channels, 1, 1, 0),  # 1x1conv,替代linear
            nn.Sigmoid()  # sigmoid,将输出压缩到(0,1)
        )

    def forward(self, x):
        weights = self.attn(x)  # feature map每个通道的重要性权重(0,1),对应原文的sc
        return weights * x  # 将计算得到的weights与输入的feature map相乘

おすすめ

転載: blog.csdn.net/Peach_____/article/details/128677412
おすすめ