GAM アテンションのメカニズム
GAM アテンション メカニズム - 空間チャネル次元にまたがり、「グローバルな」次元を越えた相互作用を増幅するための情報を保持する「グローバルな」アテンション メカニズム。
論文アドレス:グローバル アテンション メカニズム: 情報を保持してチャネルと空間の相互作用を強化する
チャネル アテンション サブモジュール
チャネル アテンション サブモジュールは、3 次元配置を使用して情報を 3 次元で保存します。次に、2 層 MLP (多層パーセプトロン) を使用して、次元間のチャネル空間依存性を増幅します。(MLP は、BAM と同じコーダー/デコーダー構造であり、その圧縮率は r です); チャネル アテンション サブモジュールを次の図に示します。
空間注意サブモジュール
空間注意サブモジュールでは、空間情報に焦点を当てるために、空間情報融合に 2 つの畳み込み層が使用されます。BAM と同じ縮小率 r は、チャネル アテンション サブモジュールからも使用されます。同時に、最大プーリング操作は情報の使用量を削減するため、悪影響を及ぼします。機能マッピングをさらに保持するために、ここではプーリング操作が削除されます。したがって、空間注意モジュールではパラメータの数が大幅に増加する場合があります。パラメータの大幅な増加を防ぐために、ResNet50 ではチャネル シャッフルによるグループ畳み込みが採用されています。グループ畳み込みを使用しない空間アテンション サブモジュールを次の図に示します。
コード:
import torch.nn as nn
import torch
class GAM_Attention(nn.Module):
def __init__(self, in_channels, rate=4):
super(GAM_Attention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(in_channels, int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels / rate), in_channels)
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
nn.BatchNorm2d(int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
nn.BatchNorm2d(in_channels)
)
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
out = x * x_spatial_att
return out
if __name__ == '__main__':
x = torch.randn(1, 64, 20, 20)
b, c, h, w = x.shape
net = GAM_Attention(in_channels=c)
y = net(x)
print(y.size())