画像セグメンテーション モデル - Unet とコード

Unet ネットワーク構造:

(1) UNet ネットワークは U 字型の構造です。完全に接続された演算を行わずに完全な畳み込みニューラル ネットワークを使用する

(2) 左側のエンコーダ部分は特徴抽出ネットワークです。ダウンサンプリングには conv と pooling を使用します。

(3) 右側のデコーダ部分は特徴融合ネットワークです。右側のアップサンプリングによって生成された特徴マップと左側のダウンサンプリングされた特徴マップがチャネル次元で連結されます。(画像のサイズは BCHW、それぞれバッチサイズ、チャネル、高さ、幅です)

アップサンプリングの目的: プーリング層は画像の幅と高さを半分にするため、画像情報が失われ、解像度が低下します。アップサンプリングにより、画像の解像度が向上し、高レベルの抽象的な特徴が保持され、左側の低レベルの表面特徴の高解像度画像と結合できます。

アップサンプリング方法: 単純な補間アップサンプリング方法の代わりに転置畳み込み nn.ConvTranspose2d() を使用すると、同じ効果が得られ、ネットワークを強化できます。

(4) 最後に、2 つの 3*3 畳み込み演算の後、1*1 畳み込みカーネルを使用すると、出力チャネルの次元は分割されるカテゴリ num_classes の数になり、生成される次元は (B, num_classes, H, W) になります。 )の特徴的な写真。

コード: 

import torch
import torch.nn as nn
import torch.nn.functional as F


def X2conv(in_channel,out_channel):
    """连续两个3*3卷积"""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU())

class DownsampleLayer(nn.Module):
    """
    下采样层
    """
    def __init__(self,in_channel,out_channel):
        super(DownsampleLayer, self).__init__()
        self.x2conv=X2conv(in_channel,out_channel)
        self.pool=nn.MaxPool2d(kernel_size=2,ceil_mode=True)

    def forward(self,x):
        """
        :param x:上一层pool后的特征
        :return: out_1转入右侧(待拼接),out_1输入到下一层,
        """
        out_1=self.x2conv(x)
        out=self.pool(out_1)
        return out_1,out

class UpSampleLayer(nn.Module):
    """
    上采样层
    """
    def __init__(self,in_channel,out_channel):

        super(UpSampleLayer, self).__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        self.upsample=nn.ConvTranspose2d\ 
(in_channels=out_channel,out_channels=out_channel//2,kernel_size=3,stride=2,padding=1)

    def forward(self,x,out):
        '''
        :param x: decoder中:输入层特征,经过x2conv与上采样upsample,然后拼接
        :param out:左侧encoder层中特征(与右侧上采样层进行cat)
        :return:
        '''
        x=self.x2conv(x)
        x=self.upsample(x)

        # x.shape中H W 应与 out.shape中的H W相同
        if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
            # 将右侧特征H W大小插值变为左侧特征H W大小
            x = F.interpolate(x, size=(out.size(2), out.size(3)),
                            mode="bilinear", align_corners=True)


        # Concatenate(在channel维度)
        cat_out = torch.cat([x, out], dim=1)
        return cat_out

class UNet(nn.Module):
    """
    UNet模型,num_classes为分割类别数
    """
    def __init__(self,num_classes):
        super(UNet, self).__init__()
        #下采样
        self.d1=DownsampleLayer(3,64) #3-64
        self.d2=DownsampleLayer(64,128)#64-128
        self.d3=DownsampleLayer(128,256)#128-256
        self.d4=DownsampleLayer(256,512)#256-512

        #上采样
        self.u1=UpSampleLayer(512,1024)#512-1024-512
        self.u2=UpSampleLayer(1024,512)#1024-512-256
        self.u3=UpSampleLayer(512,256)#512-256-128
        self.u4=UpSampleLayer(256,128)#256-128-64

        #输出:经过一个二层3*3卷积 + 1个1*1卷积
        self.x2conv=X2conv(128,64)
        self.final_conv=nn.Conv2d(64,num_classes,kernel_size=1)  # 最后一个卷积层的输出通道数为分割的类别数
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self,x):
        # 下采样层
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)

        # 上采样层 拼接
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)

        # 最后的三层卷积
        out=self.x2conv(out8)
        out=self.final_conv(out)
        return out

if __name__ == "__main__":
    img = torch.randn((2, 3, 360, 480))  # 正态分布初始化

    model = UNet(num_classes=16)

    output = model(img)
    print(output.shape)

おすすめ

転載: blog.csdn.net/m0_63077499/article/details/127089803