【IPMI 2023】医用画像セグメンテーションのための深層学習モデルにおける境界検出の再考

医用画像セグメンテーションのための深層学習モデルにおける境界検出の再考、IPMI 2023

通訳: IPMI 2023 香港科技大学 Chen Hao チームの新作 | CTO: 医療画像セグメンテーションにおける境界検出の役割を再考する (qq.com)

論文:  https://arxiv.org/abs/2305.00678

コード:  https://github.com/xiaofang007/CTO

導入

この論文では、畳み込みニューラル ネットワーク、ビジュアル トランスフォーマー、明示的な境界検出操作を組み合わせることにより、精度と効率の最適なバランスで高精度の画像セグメンテーションを実現する、新しいネットワーク アーキテクチャ 、 、 、 をCTO提案 し Convolutionます TransformerOperator

CTO は標準的なエンコーダ/デコーダセグメンテーション パラダイムに従い、エンコーダ ネットワークは一般的な CNN バックボーン構造を採用してローカル セマンティック情報をキャプチャし、軽量の ViT 補助ネットワークを使用して長距離の依存関係を統合します。境界学習能力を強化するために、本論文はさらに、専用の境界検出操作から得られた境界マスクを、復号学習プロセスをガイドするための明示的な監視として使用する、境界ガイド付きデコーダネットワークを提案します。

コンボリューション、トランスフォーマー、オペレーター (CTO)

CTO は、エンコーダ/デコーダのパラダイムに従い、スキップ接続を使用して、低レベルの機能をエンコーダからデコーダに集約します。エンコーダ ネットワークは、主流の CNN と補助的な ViT で構成されます。デコーダ ネットワークは、境界検出オペレータを使用して学習プロセスをガイドします。

  • 畳み込みニューラル ネットワークと軽量のビジュアル トランスフォーマーを組み合わせたデュアル ストリーム エンコーダーは、画像パッチ間の画像の局所的な特徴の依存関係と長距離の特徴の依存関係をそれぞれキャプチャします。
  • 境界検出演算子 (例: Sobel) を使用して、生成された境界マスクを通じて学習プロセスをガイドする演算子ガイド付きデコーダー。モデル全体がエンドツーエンド方式でトレーニングされます。

デュアルストリームエンコーダ

Res2NetCTO はまず、バックボーン ネットワークとして選択された畳み込みフローを構築し、ローカル機能の依存関係を取得します。

CTO は、軽量の Vision Transformer ベースの補助フローを使用して、異なるイメージ パッチ間の長距離の依存関係をキャプチャします。具体的には、異なるスケールの機能ブロックを入力として受け取る複数の並列軽量 Transformer ブロックで構成されます。すべての Transformer ブロックは、ブロック埋め込み層や Transformer エンコード層など、同様の構造を共有します。

LightViT のブロック埋め込み層は、入力特徴ブロックを埋め込みベクトルに変換し、空間次元をシーケンス次元に変換するために使用されます。Transformer エンコード層は、セルフ アテンション メカニズムを備えた機能ブロックをモデル化して、異なる機能ブロック間の長距離の依存関係をキャプチャするために使用されます。Transformer モジュールにセルフ アテンション メカニズムを導入することで、LightViT は機能ブロック間の相互作用を効果的にモデル化し、画像のグローバル コンテキスト情報を抽出できます。

境界誘導デコーダ

境界ガイド デコーダは、勾配演算子モジュールを使用して前景オブジェクトの境界情報を抽出します。次に、境界最適化モジュールを通じて、境界で強化された特徴がマルチレベル エンコーダの特徴と統合され、特徴空間におけるクラス内およびクラス間の一貫性を同時に特徴付け、エンコーダの表現能力を強化することを目的としています。特徴。このアプローチにより、デコーダはセグメンテーション結果を生成する際に境界情報をより適切に利用できるようになり、より正確なセグメンテーション結果が得られます。

境界拡張モジュール (BEM)

境界最適化モジュールは、高レベルの特徴と低レベルの特徴を入力として使用し、境界情報を抽出し、境界に無関係な情報をフィルターで除外します。Sobelオペレータを水平方向Gxと垂直方向Gyに適用してグラディエントマップを得る。具体的には、この論文では 2 つの 3*3 パラメータ固定畳み込みを採用し、ストライド 1 の畳み込み演算を適用します。これら 2 つの畳み込みは次のように定義されます。

次に、これら 2 つの畳み込みを入力特徴マップに適用して、勾配マップ Mx および My を取得します。次に、勾配マップが sigmoid 関数によって正規化され、入力特徴マップと融合されて、強化されたエッジ特徴マップ Fe が得られます。

このうち、丸数字は\シグマ シグモイド関数を表す要素ごとの乗算を表し、Mxy はチャネル次元に沿った Mx と My の結合です。次に、単純な積み重ねた畳み込み層を使用して、エッジ強調された特徴マップを直接融合できます。最後に、出力フィーチャ マップは GT 境界マップによって管理され、オブジェクト内のエッジ フィーチャが削除され、境界が強調されたフィーチャが生成されます。

境界注入モジュール (BIM)

BEM によって取得された境界強調特徴は、エンコーダーによって生成された特徴の画像表現能力を向上させるための事前知識として使用できます。BIM では、前景と背景のフィーチャの表現機能を容易にするために、デュアル パス境界融合スキームが導入されています。具体的には、BIM は 2 つの入力を受け取ります。境界強調フィーチャと、エンコーダー ネットワークからの対応するフィーチャと、前のデコーダー レイヤーからのフィーチャとのチャネル レベルの接続です。これら 2 つの入力は BIM に入力されます。BIM には、それぞれ前景と背景のフィーチャ表現を容易にする 2 つの別個のパスが含まれています。

  • フォアグラウンド パスの場合、2 つの入力をチャネル次元に沿って直接連結し、一連の Conv-BN-ReLU (畳み込み、バッチ正規化、ReLU アクティベーション) レイヤーを適用してフォアグラウンド特徴を取得します。
  • 背景パスの場合、背景注意コンポーネントは背景情報に選択的に焦点を当てるように設計されています。

前景パスは前景特徴 Ffg を取得します。背景パスは背景機能 Fbg を取得します。 

前景アテンション マップは、デコーダの前層の特徴マップからシグモイドを介して取得され、背景アテンション マップは、前景アテンション マップを 1 から減算することによって取得されます。最後に、前景特徴 Ffg、背景特徴 Fbg、および前の層のデコーダ特徴が連結されて、この層の出力が取得されます。

損失関数

CTOは、内部セグメンテーションと境界セグメンテーションを含むマルチタスク モデルで、これら 2 つのタスクを共同で最適化する全体的な損失関数を定義します。

全体の損失は、主な内部セグメント化損失 L_seg と境界損失 L_bnd で構成されます。境界検出損失では、BEM からの予測のみが考慮され、このモジュールはエンコーダーの高レベルの特徴マップと低レベルの特徴マップを入力として受け取ります。

内部セグメンテーション損失

L_seg は、クロスエントロピー損失 L_CE と平均 IoU 損失 L_mIoU の加重合計です。

境界の喪失

境界損失 L_bnd は、境界検出における前景ピクセルと背景ピクセル間のカテゴリーの不均衡を考慮するため、Dice 損失が使用されます。

実験

 

キーコード

CTO_net.py

# https://github.com/xiaofang007/CTO/blob/main/CTOTrainer/network/CTO_net.py

class ConvBNR(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(ConvBNR, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=dilation, dilation=dilation, bias=bias),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class Conv1x1(nn.Module):
    def __init__(self, inplanes, planes):
        super(Conv1x1, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, 1)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        return x


class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(512, 256)
        self.block = nn.Sequential(
            ConvBNR(320 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x1, x11, p2):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x11 = self.reduce1(x11)
        p2 = self.reduce4(p2)
        p2 = F.interpolate(p2, size, mode='bilinear', align_corners=False)
        out = torch.cat((x1, x11), dim=1)
        out = torch.cat((out, p2), dim=1)
        out = self.block(out)

        return out



class EFM(nn.Module):
    def __init__(self, channel):
        super(EFM, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att + c
        x = self.conv2d(x)
        wei = self.avg_pool(x)
        wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        wei = self.sigmoid(wei)
        x = x * wei

        return x

class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class DM(nn.Module):
    def __init__(self):
        super(DM, self).__init__()
        self.predict3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )
        self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1)
        self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, xr, dualattention):
        crop_3 = F.interpolate(dualattention, xr.size()[2:], mode='bilinear', align_corners=False)
        re3_feat = self.predict3(torch.cat([xr, crop_3], dim=1))
        x = -1*(torch.sigmoid(crop_3)) + 1
        x = x.expand(-1, 64, -1, -1).mul(xr)
        x = F.relu(self.ra2_conv2(x))
        x = F.relu(self.ra2_conv3(x))
        ra3_feat = self.ra2_conv4(x)
        x = ra3_feat + crop_3 + re3_feat


        return x


class _DAHead(nn.Module):
    def __init__(self, in_channels, nclass, aux=True, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DAHead, self).__init__()
        self.aux = aux
        inter_channels = in_channels // 4
        self.conv_p1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.pam = _PositionAttentionModule(inter_channels, **kwargs)
        self.cam = _ChannelAttentionModule(**kwargs)
        self.conv_p2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.conv_c2 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        self.out = nn.Sequential(
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, nclass, 1)
        )
        if aux:
            self.conv_p3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )
            self.conv_c3 = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(inter_channels, nclass, 1)
            )

    def forward(self, x):
        feat_p = self.conv_p1(x)
        feat_p = self.pam(feat_p)
        feat_p = self.conv_p2(feat_p)

        feat_c = self.conv_c1(x)
        feat_c = self.cam(feat_c)
        feat_c = self.conv_c2(feat_c)

        feat_fusion = feat_p + feat_c

        outputs = []
        fusion_out = self.out(feat_fusion)
        outputs.append(fusion_out)
        if self.aux:
            p_out = self.conv_p3(feat_p)
            c_out = self.conv_c3(feat_c)
            outputs.append(p_out)
            outputs.append(c_out)

        return tuple(outputs)

def run_sobel(conv_x, conv_y, input):
    g_x = conv_x(input)
    g_y = conv_y(input)
    g = torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
    return torch.sigmoid(g) * input

def get_sobel(in_chan, out_chan):
    '''
    filter_x = np.array([
        [3, 0, -3],
        [10, 0, -10],
        [3, 0, -3],
    ]).astype(np.float32)
    filter_y = np.array([
        [3, 10, 3],
        [0, 0, 0],
        [-3, -10, -3],
    ]).astype(np.float32)
    '''
    filter_x = np.array([
        [1, 0, -1],
        [2, 0, -2],
        [1, 0, -1],
    ]).astype(np.float32)
    filter_y = np.array([
        [1, 2, 1],
        [0, 0, 0],
        [-1, -2, -1],
    ]).astype(np.float32)
    filter_x = filter_x.reshape((1, 1, 3, 3))
    filter_x = np.repeat(filter_x, in_chan, axis=1)
    filter_x = np.repeat(filter_x, out_chan, axis=0)

    filter_y = filter_y.reshape((1, 1, 3, 3))
    filter_y = np.repeat(filter_y, in_chan, axis=1)
    filter_y = np.repeat(filter_y, out_chan, axis=0)

    filter_x = torch.from_numpy(filter_x)
    filter_y = torch.from_numpy(filter_y)
    filter_x = nn.Parameter(filter_x, requires_grad=False)
    filter_y = nn.Parameter(filter_y, requires_grad=False)
    conv_x = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_x.weight = filter_x
    conv_y = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
    conv_y.weight = filter_y
    sobel_x = nn.Sequential(conv_x, nn.BatchNorm2d(out_chan))
    sobel_y = nn.Sequential(conv_y, nn.BatchNorm2d(out_chan))
    return sobel_x, sobel_y

class GlobalFilter(nn.Module):
    def __init__(self, dim=32, h=64, w=33, fp32fft=True):
        super().__init__()
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h
        self.fp32fft = fp32fft

    def forward(self, x):
        b, _, a, b = x.size()
        x = x.permute(0, 2, 3, 1).contiguous()

        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        #print(x.shape)
        weight = torch.view_as_complex(self.complex_weight)
       # print(x.shape)
        #print(weight.shape)
        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")

        if self.fp32fft:
            x = x.to(dtype)

        x = x.permute(0, 3, 1, 2).contiguous()

        return x

class ERB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ERB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x, relu=True):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        if relu:
            return self.relu(x + res)
        else:
            return x+res

class _PositionAttentionModule(nn.Module):
    """ Position attention module"""

    def __init__(self, in_channels, **kwargs):
        super(_PositionAttentionModule, self).__init__()
        self.conv_b = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_c = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.conv_d = nn.Conv2d(in_channels, in_channels, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_b = self.conv_b(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        feat_c = self.conv_c(x).view(batch_size, -1, height * width)
        attention_s = self.softmax(torch.bmm(feat_b, feat_c))
        feat_d = self.conv_d(x).view(batch_size, -1, height * width)
        feat_e = torch.bmm(feat_d, attention_s.permute(0, 2, 1)).view(batch_size, -1, height, width)
        out = self.alpha * feat_e + x

        return out


class _ChannelAttentionModule(nn.Module):
    """Channel attention module"""

    def __init__(self, **kwargs):
        super(_ChannelAttentionModule, self).__init__()
        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        feat_a = x.view(batch_size, -1, height * width)
        feat_a_transpose = x.view(batch_size, -1, height * width).permute(0, 2, 1)
        attention = torch.bmm(feat_a, feat_a_transpose)
        attention_new = torch.max(attention, dim=-1, keepdim=True)[0].expand_as(attention) - attention
        attention = self.softmax(attention_new)

        feat_e = torch.bmm(attention, feat_a).view(batch_size, -1, height, width)
        out = self.beta * feat_e + x

        return out
        
class EAM(nn.Module):
    def __init__(self):
        super(EAM, self).__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(2048, 256)
        self.block = nn.Sequential(
            ConvBNR(256 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1))

    def forward(self, x4, x1):
        size = x1.size()[2:]
        x1 = self.reduce1(x1)
        x4 = self.reduce4(x4)
        x4 = F.interpolate(x4, size, mode='bilinear', align_corners=False)
        out = torch.cat((x4, x1), dim=1)
        out = self.block(out)

        return out

def attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
        query.size(-1)
    )
    p_attn = F.softmax(scores, dim=-1)
    p_val = torch.matmul(p_attn, value)
    return p_val, p_attn

class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        b, c, h, w = x.size()#8,255,64,64
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)#8,32,80,80
        _key = self.key_embedding(x)#8,32,80,80
        _value = self.value_embedding(x)#8,32,80,80
        attentions = []
        for (width, height), query, key, value in zip(
            self.patchsize,
            torch.chunk(_query, len(self.patchsize), dim=1),
            torch.chunk(_key, len(self.patchsize), dim=1),
            torch.chunk(_value, len(self.patchsize), dim=1),
        ):
            #print('-----------width, height):',x.size())
           # print('-----------x.size()):',x.size())
            
            #print('-----------len(self.patchsize):',len(self.patchsize))  # 4
            
            #print('-----------_query):',_query.shape)   #8,256,64,64
            
            #print('-----------query):',query.shape)  #8,64,64,64
            
            out_w, out_h = w // width, h // height#
            ## 1) embedding and reshape
            query = query.view(b, d_k, out_h, height, out_w, width)
           # print('-----------query):',query.shape)
            
           # print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
            query = (
                query.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            key = key.view(b, d_k, out_h, height, out_w, width)
            key = (
                key.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            value = value.view(b, d_k, out_h, height, out_w, width)
            value = (
                value.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            y, _ = attention(query, key, value)

            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, out_h, out_w, d_k, height, width)
            y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
            attentions.append(y)
            output.append(y)

        output = torch.cat(output, 1)
        self_attention = self.output_linear(output)

        return self_attention



class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, in_channel=256):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
        self.feed_forward = FeedForward2D(
            in_channel=in_channel, out_channel=in_channel
        )

    def forward(self, rgb):
        self_attention = self.attention(rgb)
        output = rgb + self_attention
        output = output + self.feed_forward(output)
        return output

class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):#32,80
        super(PatchTrans, self).__init__()
        self.in_size = in_size#80

        patchsize = [
              (32,32),#80,80
              (16,16),#40,40
              (8,8),#20,20
              (4,4),#10,10
        ]

        self.t = TransformerBlock(patchsize, in_channel=in_channel)

    def forward(self, enc_feat):
        output = self.t(enc_feat)
        return output

class multi(nn.Module):
    def __init__(self, channel):
        super(EFM, self).__init__()
        t = int(abs((log(channel, 2) + 1) / 2))
        k = t if t % 2 else t + 1
        self.conv2d = ConvBNR(channel, channel, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, c, att):
        if c.size() != att.size():
            att = F.interpolate(att, c.size()[2:], mode='bilinear', align_corners=False)
        x = c * att 
        #x = self.conv2d(x)
        #wei = self.avg_pool(x)
        #wei = self.conv1d(wei.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        #wei = self.sigmoid(wei)
        #x = x * wei

        return x

class CTO(nn.Module):
    def __init__(self,seg_classes):
        super(CTO, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        # if self.training:
        # self.initialize_weights()
        self.fft = GlobalFilter(dim = 3 , h=256, w=129, fp32fft= True)
        
        self.multi_trans = PatchTrans(in_channel=256,in_size=64)
        
        
        
        self.num_class = seg_classes
        self.eam = EAM()
        self.sobel_x1, self.sobel_y1 = get_sobel(256, 1)
        self.sobel_x2, self.sobel_y2 = get_sobel(512, 1)
        self.sobel_x3, self.sobel_y3 = get_sobel(1024, 1)
        self.sobel_x4, self.sobel_y4 = get_sobel(2048, 1)
        
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        self.upsample_3 = nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        
        self.erb_db_1 = ERB(256, self.num_class)
        self.erb_db_2 = ERB(512, self.num_class)
        self.erb_db_3 = ERB(1024, self.num_class)
        self.erb_db_4 = ERB(2048, self.num_class)
        
        self.head = _DAHead(2048+256, 2048, aux=False)

        

        self.reduce1 = Conv1x1(256, 64)
        self.reduce2 = Conv1x1(512, 64)
        self.reduce3 = Conv1x1(1024, 64)
        self.reduce4 = Conv1x1(2048, 64)
        self.reduce5 = Conv1x1(2048, 1)

        self.dm1 = DM()
        self.dm2 = DM()
        self.dm3 = DM()
        self.dm4 = DM()

        self.predictor1 = nn.Conv2d(64, self.num_class, 1)
        self.predictor2 = nn.Conv2d(64, self.num_class, 1)
        self.predictor3 = nn.Conv2d(64, self.num_class, 1)
        self.predictor4 = nn.Conv2d(64, self.num_class, 1)

    # def initialize_weights(self):
    # model_state = torch.load('./models/resnet50-19c8e357.pth')
    # self.resnet.load_state_dict(model_state, strict=False)

    def forward(self, x):
        fft_fea = self.fft(x)#3,256,256
        x1, x2, x3 ,x4= self.resnet(x)#[16, 256, 64, 64]  [16, 512, 32, 32]   [16, 1024, 16, 16]   [16, 2048, 8, 8]
        
        trans = self.multi_trans(x1)#16,256,64,64
        
        s1 = run_sobel(self.sobel_x1, self.sobel_y1, x1)
        s4 = run_sobel(self.sobel_x4, self.sobel_y4, x4)
       
        edge = self.eam(s4, s1)
        edge_att = torch.sigmoid(edge)#[16, 1, 64, 64]
        
        trans = F.interpolate(trans,x4.size()[2:], mode='bilinear', align_corners=False)#256,8,8
        dual_attention = self.head(torch.cat([trans, x4], dim=1))[0]  #2048,8,8
        
        x1a = x1*edge_att
        edge_att2 = F.interpolate(edge_att, x2.size()[2:], mode='bilinear', align_corners=False)
        x2a = x2*edge_att2
        edge_att3 = F.interpolate(edge_att, x3.size()[2:], mode='bilinear', align_corners=False)
        x3a = x3*edge_att3
        
        #x1a = self.efm1(x1, edge_att)
        #x2a = self.efm2(x2, edge_att)
       # x3a = self.efm3(x3, edge_att)
       # x4a = self.efm4(x4, edge_att)
        
        x1r = self.reduce1(x1a)  
        x2r = self.reduce2(x2a)#128,32,32
        x3r = self.reduce3(x3a)#256,16,16
        
        dual_attention = self.reduce4(dual_attention)
       
        c3 = self.dm3(x3r, dual_attention) #256 16 16
        c2 = self.dm2(x2r, c3)  #128 32 32
        c1 = self.dm1(x1r, c2) #64 64 64
        

        o3 = self.predictor3(c3)
        o3 = F.interpolate(o3, scale_factor=16, mode='bilinear', align_corners=False)
        o2 = self.predictor2(c2)
        o2 = F.interpolate(o2, scale_factor=8, mode='bilinear', align_corners=False) 
        o1 = self.predictor1(c1)
        o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
        oe = F.interpolate(edge_att, scale_factor=4, mode='bilinear', align_corners=False)

        return  o3, o2, o1, oe

おすすめ

転載: blog.csdn.net/m0_61899108/article/details/131155202