pytorch コードはアテンションメカニズムの GAM を実装します

GAM アテンションのメカニズム

GAM アテンション メカニズム - 空間チャネル次元にまたがり、「グローバルな」次元を越えた相互作用を増幅するための情報を保持する「グローバルな」アテンション メカニズム。

論文アドレス:グローバル アテンション メカニズム: 情報を保持してチャネルと空間の相互作用を強化する
GAMの構造模式図

チャネル アテンション サブモジュール

チャネル アテンション サブモジュールは、3 次元配置を使用して情報を 3 次元で保存します。次に、2 層 MLP (多層パーセプトロン) を使用して、次元間のチャネル空間依存性を増幅します。(MLP は、BAM と同じコーダー/デコーダー構造であり、その圧縮率は r です); チャネル アテンション サブモジュールを次の図に示します。
チャネル アテンション サブモジュールの構造図

空間注意サブモジュール

空間注意サブモジュールでは、空間情報に焦点を当てるために、空間情報融合に 2 つの畳み込み層が使用されます。BAM と同じ縮小率 r は、チャネル アテンション サブモジュールからも使用されます。同時に、最大プーリング操作は情報の使用量を削減するため、悪影響を及ぼします。機能マッピングをさらに保持するために、ここではプーリング操作が削除されます。したがって、空間注意モジュールではパラメータの数が大幅に増加する場合があります。パラメータの大幅な増加を防ぐために、ResNet50 ではチャネル シャッフルによるグループ畳み込みが採用されています。グループ畳み込みを使用しない空間アテンション サブモジュールを次の図に示します。
空間注意サブモジュールの構造図

コード:

import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
 
        x = x * x_channel_att
 
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

おすすめ

転載: blog.csdn.net/DM_zx/article/details/132707926