境界損失の原理とコード分析

論文:非常に不均衡なセグメンテーションの境界損失

導入

医療画像のセグメンテーション タスクには、通常、深刻なクラスの不均衡の問題があります。ターゲットの前景領域のサイズは、背景領域よりも数桁小さいことがよくあります。たとえば、次の図では、前景領域は 500 分の 1 以上小さいです背景領域よりも。

セグメンテーションに一般的に使用されるクロスエントロピー損失関数には、非常に不均衡な問題におけるよく知られた欠点があります。つまり、すべてのサンプルとカテゴリが同じ重要性を持つと想定しているため、トレーニングが不安定になり、決定境界が不均衡に偏ってしまうことがよくあります。カテゴリーの数が多い。クラスの不均衡の問題の場合、一般的な戦略は、多数のクラスをダウンサンプリングしてクラスの以前の分布のバランスを再調整することですが、この戦略ではトレーニング イメージの使用が制限されます。もう 1 つの戦略は、数値の小さいカテゴリの重みを大きくし、数値の大きいカテゴリの重みを小さくする重み付けです。この方法は、一部の不均衡な問題には有効ですが、極端に不均衡な問題の処理には適していません。データに関しては依然として困難が続いています。いくつかのピクセルで計算されたクロスエントロピー勾配にはノイズが含まれることが多く、いくつかのカテゴリに大きな重みを与えるとさらにノイズが増加し、トレーニングが不安定になります。

セグメンテーションにおけるもう 1 つの一般的な損失関数である Dice 損失は、通常、不均衡な医用画像セグメンテーションの問題において CE 損失よりも優れたパフォーマンスを発揮します。ただし、非常に小さな領域に遭遇した場合、ピクセルが誤って分類されると損失が大幅に減少し、最適化が不安定になる可能性があります。さらに、サイコロ損失は精度と再現率の調和平均に対応します。真陽性が変化しない場合、偽陽性と偽陰性は同じ重要性を持ちます。そのため、サイコロ損失は主に、これら 2 つのタイプのエラー数が同程度の状況に適しています。 。

貢献

CE 損失と Dice 損失はそれぞれ分布ベースと領域ベースの損失関数であり、本論文では領域空間ではなく等高線空間での距離メトリックの形式を採用した境界ベースの損失関数を提案します。境界損失は領域全体の積分ではなく、領域間の境界全体の積分を計算するため、非常に不均衡なセグメンテーション問題における領域損失に関連する問題を軽減できます。

ただし, CNN の領域ソフトマックス出力に基づいて対応する境界点をどのように表現するかは大きな課題です. この論文は, 曲線発展勾配流を計算するための離散グラフベースの最適化法の使用に触発され, 積分法を使用して輪郭を避けて境界の変化を計算します。点での局所微分計算により、最終的な境界損失はネットワーク出力領域のソフトマックス確率の一次関数和となるため、既存の領域損失と併用できます。

配合

 \(I:\Omega \subset \mathbb{R}^{2,3}\rightarrow \mathbb{R}\) は空間領域 \(\Omega\) の画像を表し、\(g:\Omega \ rightarrow \begin{Bmatrix}
0,1
\end{Bmatrix}\) は、ピクセル \(p\) がターゲット領域 \(G\subset \Omega\) (前景) に属している場合、画像のグラウンド トゥルース セグメンテーション バイナリ マップです。 area) 、\(g(p)=1\)、それ以外の場合は 0、つまり \(p\in\Omega\setminus G\) (背景領域) です。\(s_{\theta}:\Omega\rightarrow [0,1]\) はセグメンテーション ネットワークのソフトマックス確率出力を表し、\(S_{\theta}\subset\Omega\) は、セグメンテーション ネットワークによって出力された対応する前景領域を表します。モデル、つまり \( S_{\theta}=\begin{Bmatrix}
p\in\Omega|s_{\theta}(p)\geqslant \delta \ 
end{Bmatrix}\)、ここで \(\delta\)は事前に設定されたしきい値です。

私たちの目的は、領域境界空間 \(\Omega\) における距離計量の形式をとる境界損失関数 \(Dist(\partial G,\partial S_{\theta })\) を構築することです。ここで \(\部分 G\) は、グラウンド トゥルース領域 \(G\) の境界の表現 (境界上のすべての点の集合和など)、\(\partial S_{\theta }\) は、ネットワーク出力によって定義されるセグメンテーション領域。\(\partial S_{\theta }\) 上の点をネットワーク出力領域 \(s_{\theta }\) の微分可能関数として表現する方法は明らかではありません。形状空間における非対称 \(L_{2}\ distance\) の次の表現を考えてみましょう。これは、2 つの隣接する境界 \(\partial S\) と \(\partial G\) の間の距離の変化を評価します。

ここで、\(p\in\Omega\) は境界上の点です\(\partial G\)、\(y_{\partial S}(p)\) は境界上の対応する点です\(\partial S\ )、つまり、\(y_{\partial S}(p)\) は、次のように、\(\partial G\) と \(\partial S\) 上の点 \(p\) にある交点です。図 2(a) \(\left \| \cdot \right \|\) が \(L_{2}\) ノルムを表すことを示しています。輪郭上の点 \(\partial S\) を直接呼び出す他の輪郭距離と同様に、式 (2) を \(\partial S=\partial S_{\theta}\) の損失関数として直接使用することはできません。しかし、式 (2) の微分境界変化が積分法によって近似できることを証明するのは簡単です。積分法では、次のように輪郭上の点を含む微分計算を回避し、面積積分を使用して境界変化を表します。

ここで、 \(\bigtriangleup S\) は 2 つの等高線間の面積を表し、 \(D_{G}:\Omega\rightarrow \mathbb{R}^{+}\) は境界 \(\partial G \ )距離マップ、つまり \(D_{G}(q)\) は任意の点 \(q\in\Omega\) と最も近い点 \(z_{\partial G}( q 間の距離)\ を表します): \(D_{G}(q)=\left \| q-z_{\partial G}(q) \right \|\)、図 2(b) に示すように。

この近似を証明するには、法線と距離のグラフ \(2D_{G }(q)\) を次の変換で積分することで、\(\left \| y_{\partial S(p)}-p \right を得ることができます。 \|^{2}\)

式 (3) から、さらに次の式が得られます。

ここで、 \(s:\Omega\rightarrow \left \{ 0,1 \right \}\) は領域 \(S\) のバイナリ指標関数です: \(s(q)=1\ if\ q\in S \) はターゲットに属します。それ以外の場合は 0 です。\(\phi _{G}:\Omega\rightarrow \mathbb{R}\) は境界 \(\partial G\) のレベル集合表現です: \(\phi _{G}(q)=-D_ {G }(q)\ if\ q\in G\) それ以外の場合\(\phi _{G}(q)=D_{G}(q)\)。\(S=S_{\theta}\) の場合、つまり、ネットワーク \(s_{\theta}(q)\) のソフトマックス出力を使用して、式 (4) の \(s(q)\) を置き換えます。 、境界損失は次のようになります。

式 (4) の最後の項にはモデル パラメーターが含まれていないため、最後の項を削除したことに注意してください。レベルセット関数 \(\phi_{G}\) は、gt area \(G\) に基づいて事前に直接計算されます。境界損失は、 \(N\) クラスのセグメンテーション問題に対して一般的に使用される領域ベースの損失関数と組み合わせることができます

ここで、 \(\alpha \in\mathbb{R}\) は 2 つの損失のバランスをとる重みパラメータです。

式 (5) では、各点 \(q\) のソフトマックス出力は距離関数によって重み付けされます。領域ベースの損失関数では、この境界までの距離の情報は無視され、領域内の各点は境界の距離とサイズはすべて同じ重みで処理されます。

著者が提案した境界損失では、距離関数内のすべての負の値が保持され(gt 領域内のすべてのピクセルに対するモデルのソフトマックス予測が 1 である)、すべての正の値が破棄されます(つまり、つまり、背景に対するモデルのソフトマックス予測です。両方が 0 の場合、境界損失はグローバル最小値に達します。つまり、モデルのソフトマックス予測がグラウンド トゥルースを正確に出力するときに境界損失が最小になります。これにより、境界の有効性も検証されます。損失。

その後の実験でわかるように、良好な結果を得るには通常、境界損失と面積損失を組み合わせる必要があります。記事内で著者が説明している理由がよく分からないので、原文を掲載します。

 「前に説明したように、境界損失の大域的最適値は厳密に負の値に対応し、ソフトマックス確率は空ではない前景領域を生成します。しかし、ほとんどどこでもソフトマックス確率がほぼヌル値である空の前景は、次のように対応します。非常に低い勾配です したがって、この自明な解は極小値または鞍点に近いです。これが、境界損失を領域損失と統合する理由です。」

実験

地域別損失の比較

他の損失関数との比較実験では、 \alpha はリバランス戦略を使用します。つまり、初期値は 0.01 で、エポックごとに 0.01 ずつ増加します。

表から、クロスエントロピー損失、一般ダイス損失、フォーカス損失のいずれにおいても、境界損失を組み合わせることで一定の精度向上が見られ、境界損失の有効性が示されていることがわかります。 

\(\alpha\) の選択

著者は 3 つの異なる方法を比較しました。1 つは定数 \(\alpha\)、つまり \(\alpha\) の値は学習プロセス全体を通じて変化しません。2 つ目は \(\alpha\) を増加させる、つまり初期設定は 0 より大きいが比較的小さい値です。各エポックの後、 \(\alpha\) の値は徐々に増加しますが、局所的な損失の重みはトレーニングが終了するまで変化しません。 2 つの重み損失は​​同じです。3 つ目はリバランス \(\alpha\)、つまり \((1-\alpha)L_{R}+\alpha L_{B}\) の方法で 2 つの損失を結合し、増加させます。各エポック後の \(\alpha\) の値、トレーニングが進むにつれて、境界損失の重みはますます大きくなり、一方、領域損失の重みはますます小さくなります。実験結果は以下の通りです

リバランス戦略が最適な結果を達成したことがわかり、この戦略は他の領域の損失結果に関するすべての比較実験でも使用されました。

実装

このうちデータはグラウンドトゥルースですが、ここでは前景と背景の 2 つのカテゴリの場合のみを考えます。ロジッツはソフトマックス後の出力です。便宜上、これはモデル出力の各ピクセルを argmax またはしきい値を介して対応するカテゴリに分割することに相当します。実際、ここでの値は [0, 1 ] の間のソフトマックスの出力である必要があります。 。距離マップは、scipy ライブラリの distance_transform_edt 関数によって計算されます。この関数の概要については、  scipy.ndimage. distance_transform_edt および cv2. distanceTransform の使用法を参照してください。

import torch
import numpy as np
from torch import einsum
from torch import Tensor
from scipy.ndimage import distance_transform_edt as distance
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union


# switch between representations
def probs2class(probs: Tensor) -> Tensor:
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    assert simplex(probs)

    res = probs.argmax(dim=1)
    assert res.shape == (b, w, h)

    return res


def probs2one_hot(probs: Tensor) -> Tensor:
    _, C, _, _ = probs.shape
    assert simplex(probs)

    res = class2one_hot(probs2class(probs), C)
    assert res.shape == probs.shape
    assert one_hot(res)

    return res


def class2one_hot(seg: Tensor, C: int) -> Tensor:
    if len(seg.shape) == 2:  # Only w, h, used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))

    b, w, h = seg.shape  # type: Tuple[int, int, int]

    res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32)
    assert res.shape == (b, C, w, h)
    assert one_hot(res)

    return res


def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    assert one_hot(torch.Tensor(seg), axis=0)
    C: int = len(seg)

    res = np.zeros_like(seg)
    # res = res.astype(np.float64)
    for c in range(C):
        posmask = seg[c].astype(np.bool)

        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res


def simplex(t: Tensor, axis=1) -> bool:
    _sum = t.sum(axis).type(torch.float32)
    _ones = torch.ones_like(_sum, dtype=torch.float32)
    return torch.allclose(_sum, _ones)


def one_hot(t: Tensor, axis=1) -> bool:
    return simplex(t, axis) and sset(t, [0, 1])

    # Assert utils


def uniq(a: Tensor) -> Set:
    return set(torch.unique(a.cpu()).numpy())


def sset(a: Tensor, sub: Iterable) -> bool:
    return uniq(a).issubset(sub)


class SurfaceLoss():
    def __init__(self):
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc: List[int] = [1]  # 这里忽略背景类  https://github.com/LIVIAETS/surface-loss/issues/3

    # probs: bcwh, dist_maps: bcwh
    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multiplied = einsum("bcwh,bcwh->bcwh", pc, dc)

        loss = multiplied.mean()

        return loss


if __name__ == "__main__":
    data = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
                          [0, 1, 1, 0, 0, 0, 0],
                          [0, 1, 1, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0]]])  # (b, h, w)->(1,4,7)

    data2 = class2one_hot(data, 2)  # (b, num_class, h, w): (1,2,4,7)
    data2 = data2[0].numpy()  # (2,4,7)
    data3 = one_hot2dist(data2)  # bcwh

    logits = torch.tensor([[[0, 0, 0, 0, 0, 0, 0],
                            [0, 1, 1, 1, 1, 1, 0],
                            [0, 1, 1, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0]]])  # (b, h, w)

    logits = class2one_hot(logits, 2)

    Loss = SurfaceLoss()
    data3 = torch.tensor(data3).unsqueeze(0)

    res = Loss(logits, data3, None)
    print('loss:', res)

特定の種類のターゲット エリアでは、距離マップを計算するときに、エリア外の距離はすべて正の値、エリア内の距離はすべて負の値になり、エリア境界から離れるほど絶対値が大きくなることに注意してください。価値。複数のカテゴリがある場合、距離マップはカテゴリごとに個別に計算され、各カテゴリの対象領域は前景値 1、その他の領域は背景値 0 として扱われます。理想的には、モデルはエリア外のすべてのピクセルを背景として予測し、つまりすべての予測が 0 であり、エリア内のすべてのピクセルを前景、つまり 1 として予測する必要があります。このとき、損失は負であり、グローバル最小値に達します。 。 

おすすめ

転載: blog.csdn.net/ooooocj/article/details/126560722