【AAAI2023】リニアトランスフォーマーによるヘッドフリー軽量セマンティックセグメンテーション

線形トランスフォーマーを使用したヘッドフリー軽量セマンティック セグメンテーション、AAAI2023

解釈: Ali チームの新作 | AFFormer: 画像周波数情報を使用して軽量トランスフォーマー セマンティック セグメンテーション アーキテクチャを構築 (qq.com)

論文: https://arxiv.org/abs/2301.04648

コード: GitHub - dongbo811/AFFormer

ガイド

この論文では、Adaptive Frequency Transformer (AFFormer) と呼ばれるセマンティック セグメンテーション アーキテクチャを提案します。AFFormer は並列アーキテクチャを採用し、プロトタイプ表現を特定の学習可能なローカル記述として活用します。これにより、デコーダが置き換えられ、高解像度の特徴に関する豊富な画像セマンティクスが保持されます。 デコーダを削除するとほとんどの推論計算を圧縮できますが、並列アーキテクチャの精度は依然として計算リソースが少ないため制限されます。したがって、計算コストをさらに節約するために、ピクセル埋め込みとプロトタイプ表現に異種オペレーター (CNN および Vision Transformer) が採用されています。さらに、空間領域の観点から Vision Transformer の複雑さを線形化することは非常に困難です。セマンティック セグメンテーションは周波数情報に非常に敏感であるため、この論文では、標準的なセルフ アテンションの O(n^2) の複雑さを置き換えるために、複雑さ O(n) の適応周波数フィルターを備えた軽量のプロトタイプ学習ブロックを構築します。 

広く使用されているデータセットに対する広範な実験により、AFFormer が 3M パラメーターを維持しながら優れた精度を達成することが示されています。ADE20K データセットでは、AFFormer は 41.8 mIoU および 4.6 GFLOP を達成します。これはSegformer よりも 4.4 mIoU 高く、GFLOPS は 45% 削減されますCityscapes データセットでは、AFFormer は 78.7 mIoU と 34.4 GFLOP を達成し、これはSegformer より 2.5 mIoU 高く、GFLOPS は 72.5% 削減されます

序章

AFFormer は、ADE20K および Cityscapes データセットで大幅に低い FLOP でより高い精度を実現します。

セマンティック セグメンテーションは、画像をサブ領域 (ピクセルのセット) に分割するタスクであり、ピクセル レベルの密な予測とマルチクラス表現という 2 つの独自の機能を備えており、画像セマンティクスのグローバルな誘導が必要です。これまでのセマンティック セグメンテーション手法は、主に、バックボーンとして分類ネットワークを使用してマルチスケール特徴を抽出し、マルチスケール特徴間の関係を確立する複雑なデコーダ ヘッドを設計することに重点を置いていました。ただし、パラメータのサイズが膨大になり、計算コストが高くなります。この固有の設計により、その開発が制限され、その応用が妨げられます。したがって、この論文では「セマンティック セグメンテーションは画像分類と同じくらい簡単ですか?」という質問をしています。

ビジュアル トランスフォーマー (ViT) には大きな可能性がありますが、パフォーマンスとメモリ使用量のバランスをとるという課題に直面しています。既存の方法は、トークンまたはスライディング ウィンドウの数を減らすことでこの状況を軽減しますが、計算の複雑さの軽減には限界があり、セグメンテーション タスクのグローバルまたはローカルのセマンティクスを損なうことさえあります。同時に、セマンティック セグメンテーションは基礎研究分野として幅広い応用シナリオがあり、さまざまな解像度の画像を処理する必要があります。上の図に示されているように、効率的な Segformer は PSPNet や DeepLabV3+ と比較して大きな進歩を遂げましたが、依然として高解像度化による膨大な計算負荷に直面しています。この論文は別の疑問を提起しています。効率的で柔軟な Transformer ネットワークは、セマンティック セグメンテーションの超低コンピューティング シナリオ向けに設計できるでしょうか?

この論文では、Adaptive Frequency Transformer (AFFormer) と呼ばれる軽量のセマンティック セグメンテーション アーキテクチャを提案します。AFFormer は、並列アーキテクチャを採用して、デコーダを特定の学習可能なローカル記述としてプロトタイプの表現に置き換え、高解像度の特徴に関する豊富な画像セマンティクスを保持します。さらに、ピクセル埋め込み特徴と局所記述特徴を処理するために異種オペレーターが採用され、計算コストがさらに節約されます。プロトタイプ学習 (PL) と呼ばれるトランスフォーマーベースのモジュールはプロトタイプ表現を学習するために使用されますが、ピクセル記述子 (PD) と呼ばれる畳み込みベースのモジュールはピクセル埋め込み特徴と学習されたプロトタイプ表現を入力として受け取り、それらをフルピクセルに変換して戻します。 -解決セマンティクス。

ただし、空間領域の観点から ViT の複雑さを線形化することは依然として非常に困難です。この論文では、セマンティック セグメンテーションは周波数情報にも非常に敏感であることがわかりました。したがって、複雑さ O(n) の軽量の適応周波数フィルターが構築され、標準的な自己注意 O(n^2) を置き換えるプロトタイプとして学習されます。このモジュールのコアは、周波数類似性カーネル、動的ローパス フィルターおよびハイパス フィルターで構成されており、それぞれ重要な周波数成分を強調し、周波数を動的にフィルターするという観点から、セマンティック セグメンテーションに有益な周波数情報を取得します。最後に、高周波と低周波の抽出および拡張モジュールで重みを共有することにより、計算コストがさらに削減されます。簡略化された深い畳み込み層もフィードフォワード ネットワーク (FFN) 層に埋め込まれており、融合効果を強化し、2 つの行列変換のサイズを削減します。

並列異種アーキテクチャと適応周波数フィルターの助けを借りて、最高のパフォーマンスを達成するために単一スケールの特徴の分類層 (CLS) として 1 つの畳み込み層のみが使用され、セマンティック セグメンテーションが画像分類と同じくらい簡単になります。 ADE20K、Cityscapes、COCO など、広く使用されている 3 つのデータセットで AFFormer の利点を実証します。AFFormer は、3M パラメーターのみを使用することで、最先端の軽量メソッドを大幅に上回ります。

方法

適応型周波数変換器。まず、並列ヘテロジニアス ネットワークの全体構造を示します。具体的には、まず、パッチ埋め込み後の特徴量Fをクラスタリングしてプロトタイプ特徴量Gを取得し、2つの異種オペレータを含む並列ネットワーク構造を構築する。Transformer ベースのモジュールは、G の好ましい周波数成分を捕捉してプロトタイプ表現 G' を取得するプロトタイプとして学習されます。最後に、G' が CNN ベースのピクセル記述子によって復元され、次のステージの F' が取得されます。

並列異種アーキテクチャ

セマンティック デコーダは、エンコーダによって取得された画像セマンティクスを各ピクセルに伝播し、ダウンサンプリングで失われた詳細を復元します。プロトタイプのセマンティクスを使用してピクセルのセマンティクス情報を記述するための新しい戦略が提案されています。各ステージでは、特徴 F ∈ R^ (H×W×C) が与えられると、最初に画像のプロトタイプとしてグリッド G ∈ R^(h×w×C) を初期化します。ここで、G の各点は次のようになります。ローカル クラスター センターの場合、初期状態には周囲の領域に関する情報のみが含まれます。ここでは、1 × C ベクトルを使用して、各点のローカルな意味情報を表します。特定のピクセルごとに、周囲のピクセルのセマンティクスが一貫していないため、各クラスターの中心間に重複するセマンティクスが存在します。クラスター中心は、対応する領域 α^2 で重み付け初期化され、各クラスター中心の初期化は次のように表されます。

 n=α×αの場合、wiはxiの重みを表し、αは3に設定されます。目的は、特徴 F を直接更新するのではなく、グリッド G 内の各クラスター中心 s を更新することです。h×w < H×W であるため、これにより計算が大幅に簡素化されます。

ここでは、Transformer ベースのモジュールがプロトタイプとして使用され、合計 L 層を含む各クラスターの中心を更新する方法を学習します。更新された中心は G'(s) と呼ばれます。更新されたクラスター中心ごとに、ピクセル記述子から復元されます。F'i が回復された特徴を表すものとします。これには、F からの豊富なピクセル セマンティクスだけでなく、クラスター センター G'(s) によって収集されたプロトタイプ セマンティクスも含まれます。 クラスターの中心は周囲のピクセルのセマンティクスを集約し、その結果局所的な詳細が失われるため、PD は最初にピクセル セマンティクスを使用して F で局所的な詳細をモデル化します。具体的には、F を低次元空間に投影してピクセル間の局所的な関係を確立し、各ローカル ブロックが明確な境界を維持するようにします。次に、G'(s) が F に埋め込まれ、双一次補間によって元の空間特徴 F' に復元されます。最後に、線状投影層を介して統合されます。

適応周波数フィルターによるプロトタイプ学習の動機付け

セマンティック セグメンテーションに対するさまざまな周波数成分の影響。

セマンティック セグメンテーションは非常に複雑なピクセル レベルの分類タスクであり、カテゴリの混乱が生じやすいです。周波数表現は、カテゴリの違いを学習するための新しいパラダイムとして機能し、人間の視覚では無視できる情報を掘り出すことができます上の図に示すように、人間は、ほとんどの周波数成分がフィルタリングされない限り、周波数情報の除去に対して耐性があります。ただし、このモデルは周波数情報の削除に非常に敏感であり、たとえ少量の情報でも削除すると、パフォーマンスが大幅に低下する可能性があります。これは、モデルの場合、より多くの頻度情報をマイニングすることでカテゴリ間の差異が強調され、各カテゴリ間の境界がより明確になり、それによってセマンティック セグメンテーションの効果が向上することがわかります。

特徴 F には豊富な周波数特徴が含まれているため、グリッド G の各クラスター中心もこれらの周波数情報を収集します。異なる周波数の特徴を抽出するために、以前の研究では、フーリエ変換と逆フーリエ変換に基づく方法が提案されました。ただし、このアプローチでは追加の計算オーバーヘッドが発生するため、多くのハードウェアでは使用できません。したがって、この論文では、空間領域で重要な高周波および低周波の特徴を直接キャプチャするための、ビジュアル Transformer ベースの適応周波数フィルタリング ブロックを提案します。 そのコアコンポーネントは上の図に示されており、式は次のように定義されます。

上の式は、適応周波数フィルタリング ブロックの動作を定義します。このうち、D^(fc)h、D^(lf)m(X)、D^(hf)n(X)はそれぞれ、Hグループの周波数類似性カーネル、Mグループの動的ローパスフィルタ、 N グループのダイナミック ハイパス フィルター。||・|| は連結を意味します。注目すべきことに、これらの操作は並列構造を採用して計算コストをさらに削減しており、これは重みを共有することによって実現されています。

周波数類似性カーネル (FSK)

G にはさまざまな周波数成分が分散されており、その目的は、意味解析に役立つ重要な成分を選択して強化することですこの目的のために、周波数類似性カーネル モジュールが設計されています。特徴 X ∈ R^((hw)×C) が与えられると、相対位置は畳み込み層を介して G 上にエンコードされます。まず、固定サイズの類似度カーネル A∈R^(C/H×C/H) を使用して異なる周波数成分間の対応を表現し、類似度カーネルをクエリすることで重要な周波数成分を選択します。周波数成分のキー K と値 V は線形層によって計算され、キーはソフトマックス演算によって周波数成分間で正規化されます。各コンポーネントは、次のように計算される類似性カーネル Ai,j を統合します。

 このうち、kiはKのi番目の周波数成分、vjはVのj番目の周波数成分を表します。入力 X も線形層によってクエリ Q に変換され、出力は固定サイズの類似性カーネルの対話型取得コンポーネントによって強化されます。 

ダイナミックローパスフィルター(DLF)

低周波成分は絶対画像のエネルギーの大部分を占め、意味情報の大部分を表します。ローパス フィルターは、カットオフ周波数より下の信号を通過させますが、カットオフ周波数より上の信号はブロックされます。したがって、一般的な平均プーリングがローパス フィルターとして使用されますただし、カットオフ周波数は画像ごとに異なります。この目的を達成するために、異なるカーネルとストライドが複数のグループで制御され、動的なローパス フィルターが生成されますグループ m の場合:

ここで、Λ k × k は、カーネル サイズ k × k の深さ方向の畳み込み層を示します。さらに、クエリと高周波特徴のアダマール積を使用して、セグメンテーションのノイズであるオブジェクト内の高周波を抑制しますFFN はキャプチャした周波数情報を融合するのに役立ちますが、計算量が多く、軽量設計では通常無視されます。隠れ層の次元は、次元圧縮によって失われた機能を補うために畳み込み層を導入することによって削減されます。

議論

周波数類似性カーネルの場合、計算量は O(hwC^2) です。各動的ハイパス フィルターの計算量は O(hwCk^2) で、周波数類似性カーネルの計算量よりもはるかに小さくなります。動的ローパス フィルターはグループごとの適応平均プーリングによって実装されるため、その計算量は約 O(hwC) です。したがって、モジュールの計算の複雑さは解像度に比例して増加するため、高解像度のセマンティック セグメンテーションには有益です。

実験

 

 

キーコード

アフォーマー.py

# https://github.com/dongbo811/AFFormer/blob/main/tools/afformer.py

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, size):
        H, W = size
        x = self.fc1(x)
        x = self.act(x + self.dwconv(x, H, W))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Conv2d_BN(nn.Module):
    """Convolution with BN module."""

    def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            pad=0,
            dilation=1,
            groups=1,
            bn_weight_init=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=None,
    ):
        super().__init__()

        self.conv = torch.nn.Conv2d(in_ch,
                                    out_ch,
                                    kernel_size,
                                    stride,
                                    pad,
                                    dilation,
                                    groups,
                                    bias=False)
        self.bn = norm_layer(out_ch)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))

        self.act_layer = act_layer() if act_layer is not None else nn.Identity(
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act_layer(x)

        return x


class DWConv2d_BN(nn.Module):


    def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.Hardswish,
            bn_weight_init=1,
    ):
        super().__init__()

        # dw
        self.dwconv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=out_ch,
            bias=False,
        )
        # pw-linear
        self.pwconv = nn.Conv2d(out_ch, out_ch, 1, 1, 0, bias=False)
        self.bn = norm_layer(out_ch)
        self.act = act_layer() if act_layer is not None else nn.Identity()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(bn_weight_init)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.dwconv(x)
        x = self.pwconv(x)
        x = self.bn(x)
        x = self.act(x)

        return x


class DWCPatchEmbed(nn.Module):

    def __init__(self,
                 in_chans=3,
                 embed_dim=768,
                 patch_size=16,
                 stride=1,
                 act_layer=nn.Hardswish):
        super().__init__()

        self.patch_conv = DWConv2d_BN(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride, 
            act_layer=act_layer,
        )

    def forward(self, x):

        x = self.patch_conv(x)

        return x


class Patch_Embed_stage(nn.Module):

    def __init__(self, embed_dim, num_path=4, isPool=False, stage=0):
        super(Patch_Embed_stage, self).__init__()

        if stage == 3:
            self.patch_embeds = nn.ModuleList([
                DWCPatchEmbed(
                    in_chans=embed_dim,
                    embed_dim=embed_dim,
                    patch_size=3,
                    stride=4 if (isPool and idx == 0) or (stage > 1 and idx == 1) else 1,
                ) for idx in range(num_path + 1)
            ])
        else:

            self.patch_embeds = nn.ModuleList([
                DWCPatchEmbed(
                    in_chans=embed_dim,
                    embed_dim=embed_dim,
                    patch_size=3,
                    stride=2 if (isPool and idx == 0) or (stage > 1 and idx == 1) else 1,
                ) for idx in range(num_path + 1)
            ])

    def forward(self, x):
        att_inputs = []
        for pe in self.patch_embeds:
            x = pe(x)

            att_inputs.append(x)

        return att_inputs


class ConvPosEnc(nn.Module):
    def __init__(self, dim, k=3):
        super(ConvPosEnc, self).__init__()

        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size):
        B, N, C = x.shape
        H, W = size

        feat = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(feat) + feat
        x = x.flatten(2).transpose(1, 2)

        return x

class LowPassModule(nn.Module):
    def __init__(self, in_channel, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(size) for size in sizes])
        self.relu = nn.ReLU()
        ch =  in_channel // 4
        self.channel_splits = [ch, ch, ch, ch]

    def _make_stage(self, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        return nn.Sequential(prior)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        feats = torch.split(feats, self.channel_splits, dim=1)
        priors = [F.upsample(input=self.stages[i](feats[i]), size=(h, w), mode='bilinear') for i in range(4)]
        bottle = torch.cat(priors, 1)
        
        return self.relu(bottle)
    
    
class FilterModule(nn.Module):
    def __init__(self, Ch, h, window):

        super().__init__()

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 at default.
            padding_size = (cur_window + (cur_window - 1) *
                            (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch,
                cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch for x in self.head_splits]
        self.LP = LowPassModule(Ch * h)

    def forward(self, q, v, size):
        B, h, N, Ch = q.shape
        H, W = size

        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        v_img = rearrange(v, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W)
        LP = self.LP(v_img)
        # Split according to channels.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        HP_list = [
            conv(x) for conv, x in zip(self.conv_list, v_img_list)
        ]
        HP = torch.cat(HP_list, dim=1)
        # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        HP = rearrange(HP, "B (h Ch) H W -> B h (H W) Ch", h=h)
        LP = rearrange(LP, "B (h Ch) H W -> B h (H W) Ch", h=h)

        dynamic_filters = q * HP + LP
        return dynamic_filters


class Frequency_FilterModule(nn.Module):

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_scale=None,
            attn_drop=0.0,
            proj_drop=0.0,
            shared_crpe=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size):
        B, N, C = x.shape

        # Generate Q, K, V.
        qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
                                   C // self.num_heads).permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Factorized attention.
        k_softmax = k.softmax(dim=2)
        k_softmax_T_dot_v = einsum("b h n k, b h n v -> b h k v", k_softmax, v)
        factor_att = einsum("b h n k, b h k v -> b h n v", q,
                            k_softmax_T_dot_v)

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)

        # Merge and reshape.
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

class MHCABlock(nn.Module):
    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=3,
            drop_path=0.0,
            qkv_bias=True,
            qk_scale=None,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            shared_cpe=None,
            shared_crpe=None,
    ):
        super().__init__()

        self.cpe = shared_cpe
        self.crpe = shared_crpe
        self.factoratt_crpe = Frequency_FilterModule(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            shared_crpe=shared_crpe,
        )
        self.mlp = Mlp(in_features=dim, hidden_features=dim * mlp_ratio)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

    def forward(self, x, size):
        if self.cpe is not None:
            x = self.cpe(x, size)
        cur = self.norm1(x)
        x = x + self.drop_path(self.factoratt_crpe(cur, size))

        cur = self.norm2(x)
        x = x + self.drop_path(self.mlp(cur, size))
        return x


class MHCAEncoder(nn.Module):
    def __init__(
            self,
            dim,
            num_layers=1,
            num_heads=8,
            mlp_ratio=3,
            drop_path_list=[],
            qk_scale=None,
            crpe_window={
                3: 2,
                5: 3,
                7: 3
            },
    ):
        super().__init__()

        self.num_layers = num_layers
        self.cpe = ConvPosEnc(dim, k=3)
        self.crpe = FilterModule(Ch=dim // num_heads,
                                  h=num_heads,
                                  window=crpe_window)
        self.MHCA_layers = nn.ModuleList([
            MHCABlock(
                dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop_path=drop_path_list[idx],
                qk_scale=qk_scale,
                shared_cpe=self.cpe,
                shared_crpe=self.crpe,
            ) for idx in range(self.num_layers)
        ])

    def forward(self, x, size):
        H, W = size
        B = x.shape[0]
        for layer in self.MHCA_layers:
            x = layer(x, (H, W))

        # return x's shape : [B, N, C] -> [B, C, H, W]
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return x


class Restore(nn.Module):

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Hardswish,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()

        out_features = out_features or in_features
        hidden_features = in_features // 2
        self.conv1 = Conv2d_BN(in_features,
                               hidden_features,
                               act_layer=act_layer)
        self.dwconv = nn.Conv2d(
            hidden_features,
            hidden_features,
            3,
            1,
            1,
            bias=False,
            groups=hidden_features,
        )
        self.norm = norm_layer(hidden_features)
        self.act = act_layer()
        self.conv2 = Conv2d_BN(hidden_features, out_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    def forward(self, x):
        identity = x
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        feat = self.norm(feat)
        feat = self.act(feat)
        feat = self.conv2(feat)

        return identity + feat


class MHCA_stage(nn.Module):
    def __init__(
            self,
            embed_dim,
            out_embed_dim,
            num_layers=1,
            num_heads=8,
            mlp_ratio=3,
            num_path=4,
            drop_path_list=[],
            id_stage=0,
    ):
        super().__init__()



        self.Restore = Restore(in_features=embed_dim, out_features=embed_dim)
        if id_stage > 0:
            self.aggregate = Conv2d_BN(embed_dim * (num_path),
                                       out_embed_dim,
                                       act_layer=nn.Hardswish)
            self.mhca_blks = nn.ModuleList([
                MHCAEncoder(
                    embed_dim,
                    num_layers,
                    num_heads,
                    mlp_ratio,
                    drop_path_list=drop_path_list,
                ) for _ in range(num_path)
            ])
        else:
            self.aggregate = Conv2d_BN(embed_dim * (num_path),
                                       out_embed_dim,
                                       act_layer=nn.Hardswish)

    def forward(self, inputs, id_stage):

        if id_stage > 0:
            att_outputs = [self.Restore(inputs[0])]
            for x, encoder in zip(inputs[1:], self.mhca_blks):
                # [B, C, H, W] -> [B, N, C]
                _, _, H, W = x.shape

                x = x.flatten(2).transpose(1, 2)
                att_outputs.append(encoder(x, size=(H, W)))

            for i in range(len(att_outputs)):
                if att_outputs[i].shape[2:] != att_outputs[0].shape[2:]:
                    att_outputs[i] = F.interpolate(att_outputs[i], size=att_outputs[0].shape[2:], mode='bilinear',
                                                   align_corners=True)

            out_concat = att_outputs[0] + att_outputs[1]
        else:
            out_concat = self.Restore(inputs[0] + inputs[1])

        out = self.aggregate(out_concat)

        return out


class Cls_head(nn.Module):
    """a linear layer for classification."""

    def __init__(self, embed_dim, num_classes):
        super().__init__()

        self.cls = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # (B, C, H, W) -> (B, C, 1)

        x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
        # Shape : [B, C]
        out = self.cls(x)
        return out


def dpr_generator(drop_path_rate, num_layers, num_stages):
    """Generate drop path rate list following linear decay rule."""
    dpr_list = [
        x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))
    ]
    dpr = []
    cur = 0
    for i in range(num_stages):
        dpr_per_stage = dpr_list[cur:cur + num_layers[i]]
        dpr.append(dpr_per_stage)
        cur += num_layers[i]

    return dpr



class AFFormer(BaseModule):
    def __init__(
        self,
        img_size=224,
        num_stages=4,
        num_path=[4, 4, 4, 4],
        num_layers=[1, 1, 1, 1],
        embed_dims=[64, 128, 256, 512],
        mlp_ratios=[8, 8, 4, 4],
        num_heads=[8, 8, 8, 8],
        drop_path_rate=0.0,
        in_chans=3,
        num_classes=1000,
        strides=[4, 2, 2, 2],
        pretrained=None, init_cfg=None,
    ):
        super().__init__()
        if isinstance(pretrained, str):
            self.init_cfg = pretrained
        self.num_classes = num_classes
        self.num_stages = num_stages

        dpr = dpr_generator(drop_path_rate, num_layers, num_stages)

        self.stem = nn.Sequential(
            Conv2d_BN(
                in_chans,
                embed_dims[0] // 2,
                kernel_size=3,
                stride=2,
                pad=1,
                act_layer=nn.Hardswish,
            ),
            Conv2d_BN(
                embed_dims[0] // 2,
                embed_dims[0],
                kernel_size=3,
                stride=2,
                pad=1,
                act_layer=nn.Hardswish,
            ),
        )

        self.patch_embed_stages = nn.ModuleList([
            Patch_Embed_stage(
                embed_dims[idx],
                num_path=num_path[idx],
                isPool=True if idx == 1 else False,
                stage=idx,
            ) for idx in range(self.num_stages)
        ])

        self.mhca_stages = nn.ModuleList([
            MHCA_stage(
                embed_dims[idx],
                embed_dims[idx + 1]
                if not (idx + 1) == self.num_stages else embed_dims[idx],
                num_layers[idx],
                num_heads[idx],
                mlp_ratios[idx],
                num_path[idx],
                drop_path_list=dpr[idx],
                id_stage=idx,
            ) for idx in range(self.num_stages)
        ])

        # Classification head.
        # self.cls_head = Cls_head(embed_dims[-1], num_classes)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self):
        if isinstance(self.init_cfg, str):
            logger = get_root_logger()
            load_checkpoint(self, self.init_cfg, map_location='cpu', strict=False, logger=logger)
            
        else:
            self.apply(self._init_weights)
    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):

        # x's shape : [B, C, H, W]

        x = self.stem(x)  # Shape : [B, C, H/4, W/4]
        out = []
        for idx in range(self.num_stages):
            att_inputs = self.patch_embed_stages[idx](x)
            x = self.mhca_stages[idx](att_inputs, idx)

            out.append(x)
            
        return out



@BACKBONES.register_module()
class afformer_base(AFFormer):
    def __init__(self, **kwargs):
        super(afformer_base, self).__init__(
                    img_size=224,
        num_stages=4,
        num_path=[1, 1, 1, 1],
        num_layers=[1, 2, 6, 2],
        embed_dims=[32, 96, 176, 216],
        mlp_ratios=[2, 2, 2, 2],
        num_heads=[8, 8, 8, 8], **kwargs)


@BACKBONES.register_module()
class afformer_small(AFFormer):
    def __init__(self, **kwargs):
        super(afformer_small, self).__init__(
                    img_size=224,
        num_stages=4,
        num_path=[1, 1, 1, 1],
        num_layers=[1, 2, 4, 2],
        embed_dims=[32, 64, 176, 216],
        mlp_ratios=[2, 2, 2, 2],
        num_heads=[8, 8, 8, 8], **kwargs)

        

@BACKBONES.register_module()
class afformer_tiny(AFFormer):
    def __init__(self, **kwargs):
        super(afformer_tiny, self).__init__(
                    img_size=224,
        num_stages=4,
        num_path=[1, 1, 1, 1],
        num_layers=[1, 2, 4, 2],
        embed_dims=[32, 64, 160, 216],
        mlp_ratios=[2, 2, 2, 2],
        num_heads=[8, 8, 8, 8], **kwargs)

Supongo que te gusta

Origin blog.csdn.net/m0_61899108/article/details/131154539
Recomendado
Clasificación