SGE アテンション・オブ・アテンション・メカニズム

空間グループごとの強化: 畳み込みネットワークにおけるセマンティック特徴学習の改善

ペーパーリンク

論文:空間グループごとの強化: 畳み込みネットワークにおけるセマンティック特徴学習の改善

モデル構造

ここに画像の説明を挿入

論文の主な内容

畳み込みニューラル ネットワーク (CNN) は、さまざまなレベルおよび部分で意味論的なサブ特徴を収集することによって、複雑なオブジェクトの特徴表現を生成します。これらのサブ特徴は、通常、グループ化された形式で各層の特徴ベクトルに分散され、さまざまな意味論的エンティティを表すことができます。ただし、これらのサブ機能のアクティブ化は、同様のパターンやノイズの多い背景によって空間的に影響を受けることが多く、誤った位置特定と認識につながります。この論文は、各意味論的グループ内の各空間位置に対する注意係数を生成することで各サブ特徴の重要性を調整できる空間グループ強化(SGE)モジュールを提案します。これにより、個々のグループが自律的に強化表現を学習し、可能性を抑制できます。ノイズ。アテンションファクターは、各グループ内のグローバル特徴記述子とローカル特徴記述子の間の類似性によってのみ導かれるため、SGE モジュールの設計は非常に軽量で、追加のパラメーターや計算はほとんどありません。

import numpy as np
import torch
from torch import nn
from torch.nn import init



class SpatialGroupEnhance(nn.Module):

    def __init__(self, groups):
        super().__init__()
        self.groups=groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight=nn.Parameter(torch.zeros(1,groups,1,1))
        self.bias=nn.Parameter(torch.zeros(1,groups,1,1))
        self.sig=nn.Sigmoid()
        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h,w=x.shape
        x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w
        xn=x*self.avg_pool(x) #bs*g,dim//g,h,w
        xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w
        t=xn.view(b*self.groups,-1) #bs*g,h*w

        t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w
        std=t.std(dim=1,keepdim=True)+1e-5
        t=t/std #bs*g,h*w
        t=t.view(b,self.groups,h,w) #bs,g,h*w
        
        t=t*self.weight+self.bias #bs,g,h*w
        t=t.view(b*self.groups,1,h,w) #bs*g,1,h*w
        x=x*self.sig(t)
        x=x.view(b,c,h,w)

        return x 


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    sge = SpatialGroupEnhance(groups=8)
    output=sge(input)
    print(output.shape)

コード分​​析

このモジュールは、CNN の特徴マップの表現力を強化し、パフォーマンスを向上させることができます。
このコード ブロックでは、2 つの変数 b と c が定義されており、それぞれ入力 x のバッチ サイズとチャネル数を表します。

次に、リシェイプ操作が実行されて、各チャネルが self.groups グループに分割されます。最初の行では、x は変形後のサイズ (b*groups,dim//groups,h,w) のテンソルになります。

次に、グループに分割された特徴マップはxn = x 1 HW ∑ x xn=x\frac{1}{HW}\sum x を渡します。× n=バツHW _1xメソッドは重み付けされた特徴マップを生成します。H*W は特徴マップ内のピクセル数であり、avg_pool 演算により出力特徴マップの形状が一貫していることが保証されます。

次に、xn xnを実行します。x nの重み付け演算

その後、重み付きテンソルtttは標準化されます。つまり、平均が各要素から減算され、分散で除算されます。

この段階では、t はグループの全体平均 (ローリング平均) および標準偏差 (ローリング標準) として扱われます。グループ内の各値は、グループのローリング平均から減算され、グループのローリング標準で除算されます。正規化の結果はテンソル t に格納されます。

そして、活性化関数の計算、つまり t * self.weight + self.bias が実行されます。ここで、self.weight と self.bias は学習可能なパラメーターです。

次に、x * t のシグモイド演算が実行され、結果が再度グループ化されます。

正規化された結果に元の入力「x」(コードの 3 行目)を乗算し、結果を最終出力に再形成します。

最後に、入力からのランダムなテンソルを使用してモデルをテストし、最終的な出力形状を出力します。

おすすめ

転載: blog.csdn.net/qq_38915354/article/details/130552516