pytorch コード実装によるアテンション メカニズム (SEnet、CBAM)

CNN でのアテンション メカニズムの使用は、ネットワークのパフォーマンスを向上させる効果的な方法であり、その中心的な焦点は、より注意が必要な場所にネットワークに注意を向けさせることです。アテンション メカニズムを使用して、ネットワークが特定の領域に適応的に集中できるようにします。主な実装は重み付けされた値です。コードでは、チャネルまたはピクセルに重み付けすることによってアテンション メカニズムも実装しています。

一般に、注意メカニズムは、チャネル注意と空間注意、または両方の組み合わせに分けられます。

1. チャンネルへの注目

CNN の特徴マップには、幅 (W)、高さ (H)、チャネル数 (C) が含まれることがよくあります。以下に示すように:

 異なるチャネルは異なる特徴情報を抽出します。たとえば、最初のチャネルは水平方向の特徴を抽出し、2 番目のチャネルは垂直方向の特徴を抽出します。タスクが垂直方向の特徴にさらに注意を払う場合、この特徴を増幅するために 2 番目のチャネルにより高い重みを付ける必要があります。

2. 空間への注意

空間的注意 特徴マップでは、同じチャネル内のすべての領域が同じように重要であるわけではありません。たとえば、下の図では、青でマークされた位置が焦点となる位置です。空間的注意メカニズムを使用して、対応するピクセルの重みを高くして、この領域に焦点を当てる機能を実現します。

3. アテンションメカニズム(SEnet、CBAM)の紹介

ここでは主に SEnet と CBAM のアテンション メカニズムと関連する pytorch コードの実装を紹介します。

1、SEネット

(1) 論文の出典

SEnet は論文「Squeeze-and-Excitation Networks」から来ています。

論文のリンクはhttps://arxiv.org/abs/1709.01507です。

(2) SEnetの構造構成

SEnet の構造は次の図に示されており、このモデルはチャネル アテンションを実現できます。

入力フィーチャのサイズは W×H×C で、SEnet の実装手順には次の手順が含まれます。

1. グローバル平均プーリング: グローバル平均プーリングは各チャネルで実行され、1×1×C のサイズの特徴が得られます。

2. スクイーズ操作: グローバル平均プーリング後の特徴ベクトルは、全結合層 (通常は多層パーセプトロン) を通じてより小さい次元にマッピングされます。このプロセスは「スクイーズ」と呼ばれ、チャネルのグローバル表現を学習するために使用されます。一般に、この全結合層の出力次元は、圧縮率を決定するハイパーパラメータによって制御できます。

3. 励起操作: 「スクイーズ」操作の出力を完全に接続された層に渡し、次元を元のチャネル数に戻します。次に、この全結合層の出力が活性化関数 (シグモイド関数など) によって活性化されて、チャネル アテンション ウェイトが生成されます。

4. 特徴量を再スケールする: チャネル アテンション ウェイトを元の特徴量と乗算し、そのアテンション ウェイトを元の特徴量に適用して、特徴マップ内の各チャネルの重み付けを変更します。このプロセスは、チャネル アテンション ウェイトを元の特徴マップと同じ形状に拡張し、チャネルごとに乗算することで実現できます。

実装コードは次のとおりです。コードは gpt によって生成され、非常に使いやすいです。

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channels // reduction, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        out = self.avg_pool(x).view(b, c)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out).view(b, c, 1, 1)
        out = x * out.expand_as(x)
        return out


class SENet(nn.Module):
    def __init__(self, num_classes=10):
        super(SENet, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.se_block = SEBlock(32)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 32 * 32, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.ReLU(inplace=True)(out)
        out = self.se_block(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out


# 创建模型实例
model = SENet()

 2、CBAM

(1) 論文の出典

CBAM は、論文「CBAM: Convolutional Block Attention Module」に由来しています。

元のリンク:

https://arxiv.org/pdf/1807.06521.pdf

(2) CBAMの構造構成

CBAM アテンション メカニズムは、チャネル アテンション モジュールと空間アテンション モジュールで構成されます。主に次の手順が含まれます。

1. 共有グローバル平均プーリング: 入力特徴マップに対して 2 つのグローバル平均プーリング操作が実行され、それぞれ空間アテンションとチャネル アテンションを計算するために使用されます。グローバル平均プーリングは、特徴マップ全体の統計を取得できます。

2. 空間的注意: グローバル平均プーリングの後、2 つの異なる処理ブランチがフィーチャに対して実行されます。1 つのブランチは、特徴マップ内のさまざまな位置の重要性を学習するために、1x1 畳み込み層を介して空間アテンションの重みを生成します。もう 1 つの分岐は、1x1 畳み込み層を通じて空間アテンション スケーリング係数を生成することです。これは、特徴マップ内の各チャネルの重みを調整するために使用されます。これら 2 つのブランチの出力は要素ごとに乗算されて、最終的な空間アテンションが得られます。

3. チャネル アテンション: グローバル平均プーリングの後、特徴に対して 2 つの異なる処理ブランチが実行されます。1 つのブランチは、各チャネルの重要性を学習するために、完全に接続された層を通じてチャネル アテンション ウェイトを生成します。もう 1 つの分岐は、全結合層を通じてチャネル アテンション スケーリング係数を生成することです。これは、特徴マップ内の各チャネルの重みを調整するために使用されます。これら 2 つのブランチの出力は要素ごとに乗算されて、最終的なチャネル アテンションが得られます。

4. 特徴の再調整: 空間アテンションとチャネル アテンションがそれぞれ元の特徴マップに適用されます。まず、空間注意重みが拡張され、特徴マップと同じ形状になるように繰り返されます。次に、チャネル アテンション ウェイトは同じ次元を持つように拡張されます。最後に、この 2 つが要素ごとに乗算され、元の特徴マップにアテンションの重みが適用されて、最終的な再調整された特徴が得られます。

実装コードは次のとおりです。

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        out = x * out
        return out


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()

        self.conv = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        out = self.sigmoid(out)
        out = x * out
        return out


class CBAM(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()

        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        out = self.channel_att(x)
        out = self.spatial_att(out)
        return out


# 创建模型实例
model = CBAM(channels=64, reduction=16)

 

おすすめ

転載: blog.csdn.net/m0_45267220/article/details/130687607