Semantic Segmentation Series 25-BiSeNetV2 (PyTorch implementation)

Following BiSeNetV1 (Semantic Segmentation Series 16-BiSeNetV1), BiSeNetV2 was published in IJCV in 2021.

Paper link: BiSeNetV2

Compared with V1, V2 makes many improvements in the downsampling strategy, the convolution types, and feature fusion.

This article covers:

  • How BiSeNetV2 designs the Semantic Branch and the Detail Branch.
  • How BiSeNetV2 designs the Aggregation Layer to fuse features.
  • How BiSeNetV2 designs auxiliary losses to help train the model.
  • A code implementation of BiSeNetV2 and its application.

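Contents

Paper

Introduction

Model

Backbone: Detail Branch

Backbone: Semantic Branch

Aggregation Layer

Segmentation Head

Booster (auxiliary heads)

Implementing BiSeNetV2 and applying it to CamVid

BiSeNetV2 implementation

CamVid dataset

Training

Results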


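Paper

Introduction

The bilateral segmentation structure of BiSeNetV1 achieves good results on real-time segmentation tasks. It preserves low-level details and high-level semantics without sacrificing inference speed, striking a balance between segmentation accuracy and fast inference.

Building on this, the authors propose BiSeNetV2, a bilateral segmentation network for real-time semantic segmentation.

Compared with the first version, BiSeNetV1:

  • V2 simplifies the original structure, making the network more efficient.
  • With a more compact network structure and carefully designed components, the Semantic Branch is made deeper, and lightweight depthwise separable convolutions are used to speed up the model.
  • A more effective Aggregation Layer is designed to strengthen the connection between the Semantic Branch and the Detail Branch.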

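Model

First, look at the overall structure of the model.

Figure 1: Structure of the BiSeNetV2 model

BiSeNetV2 consists mainly of the following parts:

  1. The bilateral backbone in the purple box: the upper path is the Detail Branch and the lower path is the Semantic Branch.
  2. The Aggregation Layer in the orange box.
  3. The Auxiliary Loss branches (Booster) in the yellow box.

We start with the backbone in the purple box.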

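Backbone: Detail Branch

The Detail Branch still uses a VGG-like structure. This part is relatively simple: it has three stages, each starting with a stride-2 3×3 convolution, so it downsamples the input to 1/8 resolution while keeping wide channels (64, 64, 128 in the code below) to produce fine-grained, high-resolution feature maps.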
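As a quick check of the 1/8 figure, the sketch below applies the usual convolution output-size formula, floor((n + 2p - k) / s) + 1, to each Detail Branch stage: the first 3×3 convolution (stride 2, padding 1) halves the size and the following stride-1 convolutions with padding 1 keep it. The helper names `conv_out_size` and `detail_branch_sizes` are hypothetical and not part of the original code.

```python
def conv_out_size(n, kernel=3, stride=1, padding=0):
    """Spatial output size of a 2-D convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def detail_branch_sizes(n, num_stages=3):
    """Track the feature-map size through the Detail Branch stages.

    Each stage starts with a 3x3 stride-2 conv (padding 1); the remaining
    3x3 stride-1 convs with padding 1 leave the size unchanged.
    """
    sizes = []
    for _ in range(num_stages):
        n = conv_out_size(n, kernel=3, stride=2, padding=1)  # downsampling conv
        n = conv_out_size(n, kernel=3, stride=1, padding=1)  # size-preserving conv
        sizes.append(n)
    return sizes


print(detail_branch_sizes(224))  # [112, 56, 28]
print(detail_branch_sizes(448))  # [224, 112, 56]
```

With the 448×448 inputs used for training later in this article, the Detail Branch therefore outputs 56×56 feature maps.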

The code is as follows:

import torch
import torch.nn as nn
class DetailBranch(nn.Module):
    def __init__(self, detail_channels=(64, 64, 128), in_channels=3):
        super(DetailBranch, self).__init__()
        self.detail_branch = nn.ModuleList()

        for i in range(len(detail_channels)):
            if i == 0:
                self.detail_branch.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, detail_channels[i], 3, stride=2, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),
                    )
                )
            else:
                self.detail_branch.append(
                    nn.Sequential(
                        nn.Conv2d(detail_channels[i-1], detail_channels[i], 3, stride=2, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU()
                        )
                    )


    def forward(self, x):
        for stage in self.detail_branch:
            x = stage(x)
        return x

if __name__ == "__main__":
    x = torch.randn(3, 3, 224, 224)
    net = DetailBranch(detail_channels=(64, 64, 128), in_channels=3)
    out = net(x)
    print(out.shape)  # torch.Size([3, 128, 28, 28]): 1/8 of the 224x224 input

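Backbone: Semantic Branch

The Semantic Branch runs in parallel with the Detail Branch and is mainly used to capture high-level semantic information. It uses relatively few channels, because the spatial detail is supplied by the Detail Branch. Since high-level semantics require long-range context and a large receptive field, this branch uses a fast-downsampling strategy to enlarge the receptive field quickly, and global average pooling to embed context information.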
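To make the receptive-field argument concrete, the sketch below uses the standard recursion r_out = r_in + (k - 1) · j_in and j_out = j_in · s (r: receptive field, j: cumulative stride) for a plain chain of layers. Applied to the eight 3×3 convolutions of the Detail Branch shown above, it gives a receptive field of only 67 pixels at stride 8, which illustrates why the Semantic Branch downsamples aggressively and relies on global average pooling, whose output depends on every input pixel. The function name and the layer list are illustrative assumptions.

```python
def receptive_field(layers):
    """Receptive field and cumulative stride of a chain of conv/pool layers.

    layers: list of (kernel_size, stride) tuples, applied in order.
    Uses r_out = r_in + (k - 1) * j_in and j_out = j_in * s (dilation 1).
    """
    r, j = 1, 1  # a single input pixel, stride 1
    for k, s in layers:
        r = r + (k - 1) * j
        j = j * s
    return r, j


# Detail Branch from the code above: stage 1 has 2 convs, stages 2 and 3 have 3,
# and each stage starts with a stride-2 3x3 conv.
detail_layers = [(3, 2), (3, 1),
                 (3, 2), (3, 1), (3, 1),
                 (3, 2), (3, 1), (3, 1)]
print(receptive_field(detail_layers))  # (67, 8)
```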

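The authors designed this part more elaborately. It consists of three main components:

  1. The Stem Block, for fast downsampling.
  2. The Gather-and-Expansion Layer (GE Layer), the convolutional building block used to extract features.
  3. The Context Embedding Block (CE Block), used to embed context information.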

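Stem Block and CE Block structure

Both blocks are relatively simple. The Stem Block applies a stride-2 3×3 convolution, then runs two paths in parallel (a 1×1 convolution followed by a stride-2 3×3 convolution, and a stride-2 3×3 max pooling), concatenates them, and fuses them with a 3×3 convolution, reaching 1/4 resolution. The CE Block applies global average pooling and a 1×1 convolution, adds the pooled result back to the input by broadcasting, and finishes with a 3×3 convolution.

Figure 2: Structure of the Stem Block and CE Block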
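Concatenating the two Stem Block paths requires them to produce the same spatial size. This holds for any input size because the stride-2 3×3 convolution (padding 1) and the stride-2 3×3 max pooling (padding 1, `ceil_mode=False`, as in the code below) follow the same output-size formula. The plain-Python sketch below checks this; the helper names are hypothetical.

```python
def out_size(n, kernel, stride, padding):
    """Output size for conv or max pooling with floor rounding (ceil_mode=False)."""
    return (n + 2 * padding - kernel) // stride + 1


def stem_sizes(n):
    """Spatial sizes of the two Stem Block paths for an input of size n."""
    n = out_size(n, 3, 2, 1)                             # conv_in: 3x3, stride 2
    conv_path = out_size(out_size(n, 1, 1, 0), 3, 2, 1)  # 1x1 conv, then 3x3 stride-2 conv
    pool_path = out_size(n, 3, 2, 1)                     # 3x3 stride-2 max pooling
    return conv_path, pool_path


for size in (224, 225, 448, 513):
    conv_path, pool_path = stem_sizes(size)
    assert conv_path == pool_path  # torch.cat along channels needs equal H and W
    print(size, "->", conv_path)
```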

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class StemBlock(nn.Module):
    def __init__(self, in_channels=3, out_channels=16):
        super(StemBlock, self).__init__()

        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.conv_branch = nn.Sequential(
            nn.Conv2d(out_channels, out_channels//2, 1),
            nn.BatchNorm2d(out_channels//2),
            nn.ReLU(),            
            nn.Conv2d(out_channels//2, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ) 

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)

        self.fusion = nn.Sequential(
            nn.Conv2d(2*out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_in(x)

        x_branch = self.conv_branch(x)
        x_downsample = self.pool(x)
        out = torch.cat([x_branch, x_downsample], dim=1)
        out = self.fusion(out)

        return out
        
if __name__ == "__main__":
    x = torch.randn(3, 3, 224, 224)
    net = StemBlock()
    out = net(x)
    print(out.shape)  # torch.Size([3, 16, 56, 56]): 1/4 resolution

class CEBlock(nn.Module):
    def __init__(self,in_channels=16, out_channels=16):
        super(CEBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            # After AdaptiveAvgPool2d the shape is (batch size, N, 1, 1); BatchNorm2d then fails in training mode with batch size 1, but works with batch size > 1
            # nn.BatchNorm2d(self.in_channels)
            )

        self.conv_gap = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 1, stride=1, padding=0),
            # nn.BatchNorm2d(self.out_channels), same as above
            nn.ReLU()
            )

        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        identity = x
        x = self.gap(x)
        x = self.conv_gap(x)
        x = identity + x
        x = self.conv_last(x)
        return x

if __name__ == "__main__":
    x = torch.randn(1, 16, 224, 224)
    net = CEBlock()
    out = net(x)
    print(out.shape)  # torch.Size([1, 16, 224, 224])

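GE Block structure

Figure 3: GE Block structure ((b), (c))

Depending on whether it downsamples, the GE Block comes in two variants: GE Block (b) without downsampling and GE Block (c) with downsampling (stride 2). The authors borrow the inverted-bottleneck design of MobileNetV2 and use depthwise separable convolutions in the middle of the block to reduce computation.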
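To see how much a depthwise separable convolution saves, the sketch below counts the weights of a standard k×k convolution (k·k·C_in·C_out) and of a depthwise separable one (k·k·C_in for the depthwise part plus C_in·C_out for the 1×1 pointwise part). Biases are ignored for simplicity, and the 96 channels (16 input channels times the expansion ratio 6 used in `GELayer`) are only an example.

```python
def standard_conv_params(c_in, c_out, k=3):
    """Weights of a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


def depthwise_separable_params(c_in, c_out, k=3):
    """Weights of a k x k depthwise conv plus a 1x1 pointwise conv (bias ignored)."""
    depthwise = k * k * c_in   # one k x k filter per input channel (groups=c_in)
    pointwise = c_in * c_out   # 1x1 conv that mixes channels
    return depthwise + pointwise


c = 16 * 6  # 16 input channels expanded by exp_ratio=6, as in GELayer
std = standard_conv_params(c, c)
sep = depthwise_separable_params(c, c)
print(std, sep, round(std / sep, 1))  # 82944 10080 8.2
```

The saving grows with the channel count, which is why the expanded (wide) middle of the GE Block is where the depthwise convolutions are placed.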

The GE Block code is as follows:

import torch
import torch.nn as nn
class depthwise_separable_conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class GELayer(nn.Module):
    def __init__(self, in_channels, out_channels, exp_ratio=6, stride=1):
        super(GELayer, self).__init__()
        mid_channel = in_channels * exp_ratio
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU()
            )

        if stride == 1:
            self.dwconv = nn.Sequential(
                # ReLU in ConvModule not shown in paper
                nn.Conv2d(in_channels, mid_channel, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(mid_channel),
                nn.ReLU(),

                depthwise_separable_conv(mid_channel, mid_channel, stride=1),
                nn.BatchNorm2d(mid_channel),
                )
            self.shortcut = None
        else:
            self.dwconv = nn.Sequential(
                nn.Conv2d(in_channels, mid_channel, 3, stride=1, padding=1, groups=in_channels,bias=False),
                nn.BatchNorm2d(mid_channel),
                nn.ReLU(),
                
                # ReLU in ConvModule not shown in paper
                depthwise_separable_conv(mid_channel, mid_channel, stride=stride),
                nn.BatchNorm2d(mid_channel),       
                
                depthwise_separable_conv(mid_channel, mid_channel, stride=1),
                nn.BatchNorm2d(mid_channel),
            )

            self.shortcut = nn.Sequential(
                depthwise_separable_conv(in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),

                nn.Conv2d(out_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
                )

        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channels, kernel_size=1, stride=1, padding=0,bias=False),
            nn.BatchNorm2d(out_channels)
            )

        self.act = nn.ReLU()

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.dwconv(x)
        x = self.conv2(x)

        if self.shortcut is not None:
            shortcut = self.shortcut(identity)
            x = x + shortcut
        else:
            x = x + identity
        x = self.act(x)
        return x


if __name__ == "__main__":
    x = torch.randn(3, 16, 224, 224)
    net = GELayer(in_channels=16, out_channels=16, stride=2)
    out = net(x)
    print(out.shape)  # torch.Size([3, 16, 112, 112]): stride 2 halves the resolution

Semantic Branch code:

class SemanticBranch(nn.Module):
    def __init__(self, semantic_channels=(16, 32, 64, 128), in_channels=3, exp_ratio=6):
        super(SemanticBranch, self).__init__()
        self.in_channels = in_channels
        self.semantic_channels = semantic_channels
        self.semantic_stages = nn.ModuleList()
        
        for i in range(len(semantic_channels)):
            if i == 0:
                self.semantic_stages.append(StemBlock(self.in_channels, semantic_channels[i]))

            elif i == (len(semantic_channels) - 1):
                self.semantic_stages.append(
                    nn.Sequential(
                        GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2),
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1),
                        
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1),
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1)
                        )
                    )

            else:
                self.semantic_stages.append(
                    nn.Sequential(
                        GELayer(semantic_channels[i - 1], semantic_channels[i],
                                exp_ratio, 2),
                        GELayer(semantic_channels[i], semantic_channels[i],
                                exp_ratio, 1)
                                )
                            )

        self.semantic_stages.append(CEBlock(semantic_channels[-1], semantic_channels[-1]))



    def forward(self, x):
        semantic_outs = []
        for semantic_stage in self.semantic_stages:
            x = semantic_stage(x)
            semantic_outs.append(x)
        return semantic_outs

if __name__ == "__main__":
    x = torch.randn(3, 3, 224, 224)
    net = SemanticBranch()
    out = net(x)
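    print(out[0].shape)  # Stem Block: torch.Size([3, 16, 56, 56])  (1/4)
    print(out[1].shape)  # stage 3:    torch.Size([3, 32, 28, 28])  (1/8)
    print(out[2].shape)  # stage 4:    torch.Size([3, 64, 14, 14])  (1/16)
    print(out[3].shape)  # stage 5:    torch.Size([3, 128, 7, 7])   (1/32)
    print(out[4].shape)  # CE Block:   torch.Size([3, 128, 7, 7])   (1/32)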


    # from torchsummary import summary
    # summary(net.cuda(), (3, 224, 224))

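Aggregation Layer

The Aggregation Layer takes the outputs of the Detail Branch (1/8 resolution) and the Semantic Branch (1/32 resolution) and fuses them through the operations in Figure 4. Each branch guides the other: the Detail Branch features are multiplied element-wise by a sigmoid of the upsampled Semantic Branch features at 1/8 resolution, the downsampled Detail Branch features are multiplied element-wise by a sigmoid of the Semantic Branch features at 1/32 resolution, and the two results are summed after upsampling the second one to 1/8.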
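The fusion rule can be illustrated on tiny 1-D "feature maps" in plain Python. This is a toy sketch, not the model code: the lists stand in for single-channel feature maps, nearest-neighbour repetition replaces the bilinear upsampling used in the model, and the function names are hypothetical. It shows that each position is gated element-wise by a sigmoid and that the two gated paths are summed.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def upsample_nearest(values, factor):
    """Repeat each value `factor` times (stand-in for bilinear upsampling)."""
    return [v for v in values for _ in range(factor)]


def aggregate(detail_high, detail_low, semantic_high, semantic_low, factor=4):
    """Toy 1-D version of the bilateral guided aggregation.

    detail_high:   Detail features at high resolution (1/8 in the model)
    detail_low:    downsampled Detail features (1/32)
    semantic_high: Semantic features already upsampled to 1/8
    semantic_low:  Semantic features at 1/32
    """
    # Path 1: detail gated by semantics at high resolution (element-wise product)
    out_1 = [d * sigmoid(s) for d, s in zip(detail_high, semantic_high)]
    # Path 2: gating at low resolution, then upsample back to high resolution
    out_2 = [d * sigmoid(s) for d, s in zip(detail_low, semantic_low)]
    out_2 = upsample_nearest(out_2, factor)
    # The two paths are summed
    return [a + b for a, b in zip(out_1, out_2)]


fused = aggregate(
    detail_high=[1.0] * 8,
    detail_low=[2.0, 4.0],
    semantic_high=[0.0] * 8,  # sigmoid(0) = 0.5
    semantic_low=[0.0, 0.0],
)
print(fused)  # [1.5, 1.5, 1.5, 1.5, 2.5, 2.5, 2.5, 2.5]
```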

Figure 4: Structure of the Aggregation Layer

Code:

import torch
import torch.nn as nn
import torch.nn.functional as F
class AggregationLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AggregationLayer, self).__init__()
        self.Conv_DetailBranch_1 = nn.Sequential(
            depthwise_separable_conv(in_channels, out_channels, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 1)
        )
        
        self.Conv_DetailBranch_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        )
        
        self.Conv_SemanticBranch_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
            nn.Sigmoid()
        )

        self.Conv_SemanticBranch_2 = nn.Sequential(
            depthwise_separable_conv(in_channels, out_channels, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        
    def forward(self, Detail_x, Semantic_x):
        DetailBranch_1 = self.Conv_DetailBranch_1(Detail_x)
        DetailBranch_2 = self.Conv_DetailBranch_2(Detail_x)

        SemanticBranch_1 = self.Conv_SemanticBranch_1(Semantic_x)
        SemanticBranch_2 = self.Conv_SemanticBranch_2(Semantic_x)

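        # Guided aggregation uses element-wise products, not matrix products
        # (torch.matmul would multiply the last two dims as matrices)
        out_1 = DetailBranch_1 * SemanticBranch_1   # gating at 1/8 resolution
        out_2 = DetailBranch_2 * SemanticBranch_2   # gating at 1/32 resolution
        out_2 = F.interpolate(out_2, scale_factor=4, mode="bilinear", align_corners=True)

        # The two guided paths are summed, then refined by a 3x3 conv + BN
        out = out_1 + out_2
        out = self.conv_out(out)
        return out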

if __name__ == "__main__":
    Detail_x = torch.randn(3, 56, 224, 224)
    Semantic_x = torch.randn(3, 56, 224//4, 224//4)
    net = AggregationLayer(in_channels=56, out_channels=122)
    out = net(Detail_x, Semantic_x)
    print(out.shape)  # torch.Size([3, 122, 224, 224])
    

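Segmentation Head

The segmentation head is relatively simple to implement: a 3×3 convolution with BN and ReLU, followed by a 1×1 convolution that outputs one channel per class.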

class SegHead(nn.Module):
    def __init__(self, channels, num_classes):
        super().__init__()
        self.cls_seg = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, num_classes, 1),
        )

    def forward(self, x):
        return self.cls_seg(x)

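Booster (auxiliary heads)

The authors attach several Auxiliary Loss heads to different stages of the Semantic Branch and compare the performance of different combinations of auxiliary-loss positions; the results are reported in a table in the paper. These auxiliary heads are used only during training and can be discarded at inference, so they add no inference cost.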
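In the training code later in this article, the total loss is the main segmentation loss plus 0.2 times each of the four auxiliary losses. The sketch below writes that weighting as a small function; the name `booster_loss` and the plain-float inputs are illustrative assumptions, and 0.2 is the weight chosen in this article's training loop, not necessarily the paper's setting.

```python
def booster_loss(main_loss, aux_losses, aux_weight=0.2):
    """Total training loss: main head loss plus weighted auxiliary-head losses.

    main_loss:  loss of the head on the Aggregation Layer output
    aux_losses: losses of the auxiliary heads on the Semantic Branch stages
    aux_weight: weight applied to every auxiliary loss
    """
    return main_loss + aux_weight * sum(aux_losses)


total = booster_loss(1.0, [2.0, 2.0, 2.0, 2.0])
print(total)
```

In the PyTorch loop the same expression is applied to loss tensors, so a single `backward()` call propagates gradients through the main head and all auxiliary heads.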

Implementing BiSeNetV2 and applying it to CamVid

BiSeNetV2 implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class StemBlock(nn.Module):
    def __init__(self, in_channels=3, out_channels=16):
        super(StemBlock, self).__init__()

        self.conv_in = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.conv_branch = nn.Sequential(
            nn.Conv2d(out_channels, out_channels//2, 1),
            nn.BatchNorm2d(out_channels//2),
            nn.ReLU(),            
            nn.Conv2d(out_channels//2, out_channels, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ) 

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)

        self.fusion = nn.Sequential(
            nn.Conv2d(2*out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_in(x)

        x_branch = self.conv_branch(x)
        x_downsample = self.pool(x)
        out = torch.cat([x_branch, x_downsample], dim=1)
        out = self.fusion(out)

        return out
class depthwise_separable_conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class GELayer(nn.Module):
    def __init__(self, in_channels, out_channels, exp_ratio=6, stride=1):
        super(GELayer, self).__init__()
        mid_channel = in_channels * exp_ratio
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1,padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU()
            )

        if stride == 1:
            self.dwconv = nn.Sequential(
                # ReLU in ConvModule not shown in paper
                nn.Conv2d(in_channels, mid_channel, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(mid_channel),
                nn.ReLU(),

                depthwise_separable_conv(mid_channel, mid_channel, stride=1),
                nn.BatchNorm2d(mid_channel),
                )
            self.shortcut = None
        else:
            self.dwconv = nn.Sequential(
                nn.Conv2d(in_channels, mid_channel, 3, stride=1, padding=1, groups=in_channels,bias=False),
                nn.BatchNorm2d(mid_channel),
                nn.ReLU(),
                
                # ReLU in ConvModule not shown in paper
                depthwise_separable_conv(mid_channel, mid_channel, stride=stride),
                nn.BatchNorm2d(mid_channel),       
                
                depthwise_separable_conv(mid_channel, mid_channel, stride=1),
                nn.BatchNorm2d(mid_channel),
            )

            self.shortcut = nn.Sequential(
                depthwise_separable_conv(in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),

                nn.Conv2d(out_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
                )

        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channel, out_channels, kernel_size=1, stride=1, padding=0,bias=False),
            nn.BatchNorm2d(out_channels)
            )

        self.act = nn.ReLU()

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.dwconv(x)
        x = self.conv2(x)

        if self.shortcut is not None:
            shortcut = self.shortcut(identity)
            x = x + shortcut
        else:
            x = x + identity
        x = self.act(x)
        return x

class CEBlock(nn.Module):
    def __init__(self,in_channels=16, out_channels=16):
        super(CEBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
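            # After AdaptiveAvgPool2d the shape is (batch size, N, 1, 1); BatchNorm2d then fails in training mode with batch size 1, but works with batch size > 1. To enable BatchNorm, train with batch size > 1 and call model.eval() at test time, which avoids the error.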
            # nn.BatchNorm2d(self.in_channels)
            )

        self.conv_gap = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, 1, stride=1, padding=0),
            # nn.BatchNorm2d(self.out_channels), same as above
            nn.ReLU()
            )

        # Note: in paper here is naive conv2d, no bn-relu
        self.conv_last = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        identity = x
        x = self.gap(x)
        x = self.conv_gap(x)
        x = identity + x
        x = self.conv_last(x)
        return x

class DetailBranch(nn.Module):
    def __init__(self, detail_channels=(64, 64, 128), in_channels=3):
        super(DetailBranch, self).__init__()
        self.detail_branch = nn.ModuleList()

        for i in range(len(detail_channels)):
            if i == 0:
                self.detail_branch.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, detail_channels[i], 3, stride=2, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),
                    )
                )
            else:
                self.detail_branch.append(
                    nn.Sequential(
                        nn.Conv2d(detail_channels[i-1], detail_channels[i], 3, stride=2, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU(),

                        nn.Conv2d(detail_channels[i], detail_channels[i], 3, stride=1, padding=1),
                        nn.BatchNorm2d(detail_channels[i]),
                        nn.ReLU()
                        )
                    )


    def forward(self, x):
        for stage in self.detail_branch:
            x = stage(x)
        return x

class SemanticBranch(nn.Module):
    def __init__(self, semantic_channels=(16, 32, 64, 128), in_channels=3, exp_ratio=6):
        super(SemanticBranch, self).__init__()
        self.in_channels = in_channels
        self.semantic_channels = semantic_channels
        self.semantic_stages = nn.ModuleList()
        
        for i in range(len(semantic_channels)):
            if i == 0:
                self.semantic_stages.append(StemBlock(self.in_channels, semantic_channels[i]))

            elif i == (len(semantic_channels) - 1):
                self.semantic_stages.append(
                    nn.Sequential(
                        GELayer(semantic_channels[i - 1], semantic_channels[i], exp_ratio, 2),
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1),
                        
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1),
                        GELayer(semantic_channels[i], semantic_channels[i], exp_ratio, 1)
                        )
                    )

            else:
                self.semantic_stages.append(
                    nn.Sequential(
                        GELayer(semantic_channels[i - 1], semantic_channels[i],
                                exp_ratio, 2),
                        GELayer(semantic_channels[i], semantic_channels[i],
                                exp_ratio, 1)
                                )
                            )

        self.semantic_stages.append(CEBlock(semantic_channels[-1], semantic_channels[-1]))



    def forward(self, x):
        semantic_outs = []
        for semantic_stage in self.semantic_stages:
            x = semantic_stage(x)
            semantic_outs.append(x)
        return semantic_outs

class AggregationLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AggregationLayer, self).__init__()
        self.Conv_DetailBranch_1 = nn.Sequential(
            depthwise_separable_conv(in_channels, out_channels, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, 1)
        )
        
        self.Conv_DetailBranch_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
        )
        
        self.Conv_SemanticBranch_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
            nn.Sigmoid()
        )

        self.Conv_SemanticBranch_2 = nn.Sequential(
            depthwise_separable_conv(in_channels, out_channels, stride=1),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        
    def forward(self, Detail_x, Semantic_x):
        DetailBranch_1 = self.Conv_DetailBranch_1(Detail_x)
        DetailBranch_2 = self.Conv_DetailBranch_2(Detail_x)

        SemanticBranch_1 = self.Conv_SemanticBranch_1(Semantic_x)
        SemanticBranch_2 = self.Conv_SemanticBranch_2(Semantic_x)

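        # Guided aggregation uses element-wise products, not matrix products
        # (torch.matmul would multiply the last two dims as matrices)
        out_1 = DetailBranch_1 * SemanticBranch_1   # gating at 1/8 resolution
        out_2 = DetailBranch_2 * SemanticBranch_2   # gating at 1/32 resolution
        out_2 = F.interpolate(out_2, scale_factor=4, mode="bilinear", align_corners=True)

        # The two guided paths are summed, then refined by a 3x3 conv + BN
        out = out_1 + out_2
        out = self.conv_out(out)
        return out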

class SegHead(nn.Module):
    def __init__(self, channels, num_classes):
        super().__init__()
        self.cls_seg = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, num_classes, 1),
        )

    def forward(self, x):
        return self.cls_seg(x)

class BiSeNetV2(nn.Module):
    def __init__(self,in_channels=3,
                detail_channels=(64, 64, 128), 
                semantic_channels=(16, 32, 64, 128), 
                semantic_expansion_ratio=6,
                aggregation_channels=128,
                out_indices=(0, 1, 2, 3, 4),
                num_classes = 3):
        super(BiSeNetV2, self).__init__()

        self.in_channels = in_channels
        self.detail_channels = detail_channels
        self.semantic_expansion_ratio = semantic_expansion_ratio
        self.semantic_channels = semantic_channels
        self.aggregation_channels = aggregation_channels
        self.out_indices = out_indices
        self.num_classes = num_classes
        
        self.detail = DetailBranch(detail_channels=self.detail_channels, in_channels=self.in_channels)
        self.semantic = SemanticBranch(semantic_channels=self.semantic_channels, in_channels=self.in_channels,exp_ratio=self.semantic_expansion_ratio)
        self.AggregationLayer = AggregationLayer(in_channels=self.aggregation_channels, out_channels=self.aggregation_channels)


        # The main head sees the Aggregation Layer output, which has aggregation_channels channels
        self.seg_head_aggre = SegHead(self.aggregation_channels, self.num_classes)
        self.seg_heads = nn.ModuleList()
        self.seg_heads.append(self.seg_head_aggre)
        for channel in semantic_channels:
            self.seg_heads.append(SegHead(channel, self.num_classes))



    def forward(self, x):
        _, _, h, w = x.size()
        x_detail = self.detail(x)
        x_semantic_lst = self.semantic(x)
        x_head = self.AggregationLayer(x_detail, x_semantic_lst[-1])
        outs = [x_head] + x_semantic_lst[:-1]
        outs = [outs[i] for i in self.out_indices]

        out = tuple(outs)

        seg_out = []
        for index, stage in enumerate(self.seg_heads):
            seg_out.append(F.interpolate(stage(out[index]),size=(h,w), mode="bilinear", align_corners=True))
        return seg_out

CamVid dataset

# Import libraries
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
 
torch.manual_seed(17)
# Custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Reads images and masks and applies the augmentation transforms.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder; mask file names
            must match the image file names

    Note: the first channel of each mask is used as the class index, and the
    random flips in self.transform are applied to every dataset built from this
    class, including the validation set.
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(448, 448),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# Dataset paths
DATA_DIR = r'database/camvid/camvid/' # adjust to your own path
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True,drop_last=True)

Training

model = BiSeNetV2(num_classes=33)


from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
import monai

# training loop 100 epochs
epochs_num = 100
# Train with the SGD optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedule = monai.optimizers.LinearLR(optimizer, end_lr=0.05, num_iter=int(epochs_num*0.75))

# Loss: multi-class cross-entropy; pixels labelled 255 are ignored
lossf = nn.CrossEntropyLoss(ignore_index=255)


def evaluate_accuracy_gpu(net, data_iter, device=None):
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)

    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            output = net(X)
            pred = output[0]
            metric.add(d2l.accuracy(pred, y), d2l.size(y))
    return metric[0] / metric[1]


# Training function
def train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Output folders for the training log and the checkpoints
    os.makedirs("savefile", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    # Lists used to log training statistics

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    lr_list = []
    

    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (X, labels) in enumerate(train_iter):
            timer.start()

            if isinstance(X, list):
                X = [x.to(devices[0]) for x in X]
            else:
                X = X.to(devices[0])
            gt = labels.long().to(devices[0])

            net.train()
            optimizer.zero_grad()
            result = net(X)
            pred = result[0]
            seg_loss = loss(result[0], gt)

            aux_loss_1 = loss(result[1], gt)
            aux_loss_2 = loss(result[2], gt)
            aux_loss_3 = loss(result[3], gt)
            aux_loss_4 = loss(result[4], gt)


            loss_sum = seg_loss + 0.2*aux_loss_1 + 0.2*aux_loss_2 + 0.2*aux_loss_3 + 0.2*aux_loss_4
            l = loss_sum
            loss_sum.sum().backward()
            optimizer.step()

            acc = d2l.accuracy(pred, gt)
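            # loss_sum is a batch mean, so scale it by the batch size before accumulating
            metric.add(l.item() * labels.shape[0], acc, labels.shape[0], labels.numel())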

            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))
                
        if optimizer.state_dict()['param_groups'][0]['lr']>0.05:
            schedule.step()

        test_acc = evaluate_accuracy_gpu(net, test_iter)
        
        animator.add(epoch + 1, (None, None, test_acc))

        print(f"epoch {epoch+1}/{num_epochs} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- lr {optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {timer.sum()}")
        
        # --------- save the training log ---------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df["lr"] = lr_list
        df['time'] = time_list
        
        df.to_excel("savefile/BiseNetv2_camvid.xlsx")
        # --------- save a checkpoint every 5 epochs ---------
        if np.mod(epoch+1, 5) == 0:
            torch.save(net.state_dict(), f'checkpoints/BiseNetv2_{epoch+1}.pth')

    # Save the final model
    torch.save(net.state_dict(), f'checkpoints/BiseNetv2_last.pth')

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, schedule=schedule)
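The training code above uses MONAI's `LinearLR` with `end_lr=0.05` and `num_iter=75`, and steps it once per epoch only while the learning rate is still above 0.05. Assuming `LinearLR` sets the learning rate to base_lr + t / (num_iter - 1) · (end_lr - base_lr) after t steps (an assumption about MONAI's internals; check the MONAI documentation for your installed version), the per-epoch learning rates can be previewed in plain Python. The function `linear_lr_schedule` is hypothetical.

```python
def linear_lr_schedule(base_lr, end_lr, num_iter, num_epochs):
    """Learning rate seen after each epoch by the training loop above.

    Assumes lr(t) = base_lr + t / (num_iter - 1) * (end_lr - base_lr) after t
    scheduler steps, and that the loop only steps while lr > end_lr.
    """
    lrs, t = [], 0
    lr = base_lr
    for _ in range(num_epochs):
        if lr > end_lr:  # same guard as in train_ch13
            t += 1
            lr = base_lr + t / (num_iter - 1) * (end_lr - base_lr)
        lrs.append(lr)
    return lrs


lrs = linear_lr_schedule(base_lr=0.1, end_lr=0.05, num_iter=int(100 * 0.75), num_epochs=100)
print(lrs[0], lrs[-1])
```

Under this assumption the learning rate decreases linearly for the first 74 epochs and then stays at 0.05 for the rest of training.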

Results


Reposted from: blog.csdn.net/yumaomi/article/details/125643372