pytorch コードはアテンション メカニズムの ShuffleAttendant を実装します。

ShuffleAttention アテンション メカニズム

現在の CNN のアテンション メカニズムには主にチャネル アテンションと空間アテンションが含まれますが、現在の一部の手法 (GCNet、CBAM など) では通常この 2 つが統合されており、収束の困難さと計算負荷が大きいという問題が発生しがちです。ECANet と SGE はいくつかの最適化スキームを提案しましたが、チャネルと空間の関係を十分に活用していませんでした。したがって、著者は「異なる注意モジュールをより軽量かつ効率的な方法で融合することはできるだろうか?」という質問をします。

この問題を解決するために、著者はシャッフル アテンションを提案しました。その全体的な枠組みは下図に示されています。入力特徴が最初に g グループに分割され、次に各グループの特徴が 2 つのブランチに分割され、チャネル アテンションと空間アテンションがそれぞれ計算されることがわかります。両方のアテンションは、全結合 + シグモイドの方法を使用して計算されます。 。次に、2 つの分岐の結果が結合され、結合されて入力と同じサイズの特徴マップが得られます。最後に、シャッフル レイヤーを使用して処理します。
構造図

コードは以下のように表示されます:

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter

# https://arxiv.org/pdf/2102.00240.pdf
class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8):
        super().__init__()
        self.G=G
        self.channel=channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid=nn.Sigmoid()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        #group into subfeatures
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

おすすめ

転載: blog.csdn.net/DM_zx/article/details/132302039