リアルタイム セマンティック セグメンテーション ネットワーク STDC の原則とコード分析 (CVPR 2021)

論文:リアルタイム セマンティック セグメンテーションのための BiSeNet の再考

公式実装:GitHub - MichaelFan01/STDC-Seg: CVPR2021 論文「Re Thinking BiSeNet For Real-time Semantic Segmentation」のソースコード

サードパーティ実装:mmsegmentation/mmseg/models/decode_heads/stdc_head.py at main · open-mmlab/mmsegmentation · GitHub

既存の問題 

リアルタイム推論を実現するために、多くのリアルタイム セマンティック セグメンテーション モデルは軽量バックボーン ネットワークを使用しますが、タスク固有の設計が欠如しているため、分類タスクから借用したこれらの軽量バックボーン ネットワークは、セグメンテーション問題の解決には適さない可能性があります。

軽量のバックボーンを選択することに加えて、入力画像のサイズを制限することも推論速度を向上させる一般的な方法ですが、エッジ付近の細部や小さなオブジェクトは無視されがちです。この問題を解決するために、BiSeNet は低レベルの詳細情報と高レベルの意味情報を組み合わせるためにマルチパス構造を採用していますが、低レベルの特徴を取得するために追加のパスを追加するには時間がかかり、補助パスは常に不足しています。低レベルの情報の案内。

この記事の革新性

本論文では、Short-Term Dense Concatenate モジュール (STDC モジュール) と呼ばれる新しい構造を設計します。このモジュールは、少数のパラメータを通じてさまざまなサイズの受容野とマルチスケール情報を取得できます。STDC ネットワークは、STDC モジュールを U-net アーキテクチャにシームレスに統合することで得られ、セマンティック セグメンテーション タスクにおけるネットワーク パフォーマンスが大幅に向上します。

デコード段階では、この論文は追加のパスを追加する方法を放棄しますが、詳細ガイダンスを使用して低レベルの空間詳細情報の学習をガイドします。まず、Detail Aggregation モジュールを使用して詳細のグランド トゥルースを生成し、次に bce 損失と dice 損失を使用して詳細情報の学習を最適化します。これは副次情報の学習とみなすことができ、この側- 推論中に情報は必要ありません。

メソッドの紹介

エンコーディングネットワークの設計

短期高密度連結モジュール

STDCモジュールの構造を図(3)の(b)(c)に示します。

各モジュールは複数のブロックに分割されており、 \(ConvX_{i}\) は \(i\) 番目のブロックの計算を表すため、 \(i\) 番目のブロックの出力計算は次のようになります。

\(x_{i-1}\) と \(x_{i}\) はそれぞれ \(i\) ブロックの入力と出力です。 \(ConvX\) には畳み込み層、BN 層、 ReLU 活性化層 \(k_{i}\) はコンボリューション カーネル サイズです。

STDC モジュールでは、最初のブロックのコンボリューション カーネル サイズは 1 で、残りは 3 です。STDC モジュールの出力チャネルの数が \(N\) であると仮定します。ただし、最後の畳み込み層の畳み込みカーネルの数が前の畳み込み層の畳み込みカーネルの数と同じである場合、\( i\) ブロックは \(N /2^{i}\) です。分類タスクでは、通常、上位層ほど多くのチャネルがあります。しかし、セグメンテーションタスクでは、可変の受容野サイズとマルチスケール情報にもっと注意を払います。下位層は、小さな受容野でより粒度の細かい情報をエンコードするのに十分なチャネルを必要としますが、より大きな受容野を持つ上位層は、より多くの注意を払います。ハイレベルセマンティクス ちなみに、下位層と同じ数のチャネルを設定すると、情報の冗長性が生じる可能性があります。ダウンサンプリングは Block2 でのみ行われます。特徴情報を充実させるために、\(x_{1}\) から \(x_{n}\) までの特徴が STDC モジュールの出力としてスキップパスによって連結されます。

ネットワークアーキテクチャ

ネットワークの完全な構造を図 3(a) に示します。これには合計 6 つのステージが含まれており、stride=2 のダウンサンプリングがステージ 1 ~ 5 で実行され、ステージ 6 はグローバルな ConvX を通じて取得されます。平均プーリングと 2 つの完全に接続された層 最終予測ロジット。

通常、Stage1&2 は外観特徴を抽出するための下位層として使用されますが、効率を追求するため、各ステージに畳み込みブロックは 1 つだけあります。ステージ 3、4、および 5 の STDC モジュールの数は慎重に調整および決定され、各ステージの最初の STDC モジュールがダウンサンプリングされます。STDC ネットワークの詳細な構造を表 2 に示します。

デコーダの設計

セグメンテーション アーキテクチャ

この記事では、事前トレーニングされた STDC ネットワークをエンコーダーのバックボーンとして使用し、BiSeNet のコンテキスト パスを使用してコンテキスト情報をエンコードします。

図 4(a) に示すように、著者はステージ 3、4、および 5 を使用して、それぞれ 1/8、1/16、および 1/32 のダウンサンプリング レートで特徴マップを生成します。次に、グローバル平均プーリングを使用して、大きな受容野を持つグローバルなコンテキスト情報を提供します。次に、U 字型構造を使用してアップサンプリングし、エンコード ステージ (ステージ 4、5) の対応する部分と融合します。ここでは、 BiSeNet のAttendee Refine モジュールを借用して、stage4、5 の機能をさらに改良します。最終予測では、BiSeNet の機能融合モジュールも借用して、エンコード ステージ stage3 の 1/8 のダウンサンプリング レートの特徴とデコード ステージの対応する出力を融合します。 

最終的な Seg Head には、最終的な N 次元出力を取得するための 3x3 Conv-BN-ReLU と 1x1 conv が含まれています。N はカテゴリの数です。損失関数はクロスエントロピーであり、オンラインハードサンプルマイニングOHEMが使用されます。

STDC ネットワークは BiSeNet の全体構造を借用し、BiSeNet の ARM モジュールと FFM モジュールを直接使用していることがわかります。BiSeNet の構造を以下の図に示します。図 4 と比較すると、STDC ネットワークは BiSeNet 内の空間パスとコンテキスト パスを 1 つに結合し、浅い出力が空間パスとして使用され、ディープ GAP の出力をコンテキスト パスとして使用し、ネットワークの構造を再設計しました。BiSeNet の導入については、BiSeNet v1 の原則とコード解釈_注意改良モジュール_00000cj のブログ - CSDN ブログでご覧いただけます。

低レベル機能の詳細なガイダンス

図 5 に示すように、(c) は STDC のステージ 3 のヒート マップであり、BiSeNet の空間パスと比較すると、多くの詳細があることがわかります。低レベルの学習空間情報。具体的には、図 4(c) に示すように、詳細の予測は 2 つのカテゴリのセグメンテーション タスクとしてモデル化されます。まず、ラプラシアン オペレーターを使用して、セグメンテーション タスクの元のグランド トゥルースから詳細マップのグランド トゥルースが生成されます。図 4(a) に示すように、ステージ 3 で詳細ヘッドが挿入されて詳細特徴マップが生成され、その後、詳細 gt が空間詳細の学習をガイドするために使用されます。図 5(d) に示すように、詳細ガイダンス モジュールを追加すると、詳細情報がさらに豊富になります。

グラウンドトゥルースの生成の詳細

詳細 gt の具体的な生成プロセスを図 4(c) に示します。元のセグメント化された GT に対して、異なるステップ長を持つラプラシアン演算子を使用して、マルチスケール詳細情報を取得します。ラプラシアン カーネルは、図 4(e) に​​示されています。次に、元のサイズにアップサンプリングしてから、トレーニング可能な 1x1 畳み込みを使用して、さまざまなスケールの詳細情報を融合し、最後に 0.1 のしきい値を使用して、最終的なバイナリ詳細のグラウンドトゥルースを取得します。

詳細の損失

詳細のピクセルは非詳細のピクセルよりもはるかに多いため、これはクラスの不均衡の問題です。重み付けされたクロスエントロピーの結果はそれほど正確ではないため、著者はクロスエントロピーとダイス損失を組み合わせて詳細の学習を最適化します。サイコロの損失は前景/背景ピクセルの数に影響されないため、カテゴリの不均衡の問題を軽減できます。詳細な損失は次のとおりです

詳細ヘッドには、具体的には 3x3 Conv-BN-ReLU と 1x1 conv が含まれており、推論段階では詳細ヘッドを直接破棄できます。

実験結果

ImageNet 上の STDC ネットワークとその他の軽量モデルの結果を表 5 に示します。STDC が最高の精度と速度のバランスを達成していることがわかります。

都市景観に関する結果を表 6 に示します。

 

他のリアルタイム セグメンテーション モデルと比較して、STDC は同じ速度で最高の精度を実現します。

コード分​​析

ここでは、MMSeg での実装を例に挙げます。バックボーン ネットワーク stdc ネットワークの実装は、mmseg/models/backbones/stdc.py にあります。具体的な実装プロセスは比較的単純なので、ここでは詳しく説明しません。なお、stride=2の場合、前述したように各ステージの最初のstdcモジュールのブロック2でダウンサンプリングが行われますが、mmsegの実装では元の conv にはstride=2が設定されておらず、元の conv の前に stride=2 の conv を追加します。

ARM モジュールと FFM モジュールは両方とも変更せずに BiSeNet v1 に実装されています。詳細については、「BiSeNet v1 の原理とコード解釈」を参照してください。

入力バッチサイズ = 16、入力サイズは 480x480、ネットワークの最終出力は次のように設定します。

outputs = [outs[0]] + list(arms_out) + [feat_fuse]
# (16,256,60,60) + [(16,128,30,30),(16,128,60,60)] + (16,256,60,60)

このうち、outs[0]はstage3の出力であり、後でDetail Headを接続する必要があります。arm_out は 2 つの ARM モジュールの出力であり、mmseg の実装では、FCN がこれら 2 つの出力を監視する補助ヘッドとして使用され (論文には記載されていません)、推論段階は削除されています。feat_fuse は、FFM 融合の後、FCN が続き、bce 損失 + ダイス損失で最適化された後の空間情報とコンテキスト情報の出力です。

Detail Head のグラウンド トゥルースの生成コードは次のとおりです。論文では、ステップ長の異なる 3 つのラプラシアン畳み込みの後、トレーニング可能な 1x1 畳み込みを使用して 3 つを融合すると述べられていますが、公式実装では、他の 3 番目では、 -party 実装では、トレーニング可能な 1x1 畳み込みが使用されないため、トレーニング不可能なパラメーターが事前に設定され、トレーニングによって更新されないfusion_kernelも融合に使用されます。

class STDCHead(FCNHead):
    """This head is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        boundary_threshold (float): The threshold of calculating boundary.
            Default: 0.1.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super().__init__(**kwargs)
        self.boundary_threshold = boundary_threshold
        # Using register buffer to make laplacian kernel on the same
        # device of `seg_label`.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32,
                         requires_grad=False).reshape((1, 1, 3, 3)))
        self.fusion_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32).reshape(1, 3, 1, 1),
            requires_grad=False)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameters. However, it is a constant in original repo and other
        # codebase because it would not be added into computation graph
        # after threshold operation.
        seg_label = self._stack_batch_gt(batch_data_samples).to(
            self.laplacian_kernel)  # (16,1,480,480)
        boundary_targets = F.conv2d(
            seg_label, self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > self.boundary_threshold] = 1
        boundary_targets[boundary_targets <= self.boundary_threshold] = 0

        boundary_targets_x2 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)

        boundary_targets_x4 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x4_up = F.interpolate(
            boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(
            boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')

        boundary_targets_x2_up[
            boundary_targets_x2_up > self.boundary_threshold] = 1
        boundary_targets_x2_up[
            boundary_targets_x2_up <= self.boundary_threshold] = 0

        boundary_targets_x4_up[
            boundary_targets_x4_up > self.boundary_threshold] = 1
        boundary_targets_x4_up[
            boundary_targets_x4_up <= self.boundary_threshold] = 0

        boundary_targets_pyramids = torch.stack(
            (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
            dim=1)  # (16,3,1,480,480)

        boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)  # (16,3,480,480)
        boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                           self.fusion_kernel)

        boudary_targets_pyramid[
            boudary_targets_pyramid > self.boundary_threshold] = 1
        boudary_targets_pyramid[
            boudary_targets_pyramid <= self.boundary_threshold] = 0

        seg_labels = boudary_targets_pyramid.long()
        batch_sample_list = []
        for label in seg_labels:
            seg_data_sample = SegDataSample()
            seg_data_sample.gt_sem_seg = PixelData(data=label)
            batch_sample_list.append(seg_data_sample)

        loss = super().loss_by_feat(seg_logits, batch_sample_list)
        return loss

おすすめ

転載: blog.csdn.net/ooooocj/article/details/131502465
おすすめ