Modèle de segmentation d'image - Unet et code

Structure du réseau Unet :

(1) Le réseau UNet est une structure en U. Utilisation d'un réseau de neurones entièrement convolutionnel sans opérations entièrement connectées

(2) La partie encodeur à gauche est le réseau d'extraction de caractéristiques : utilisez conv et pooling pour le sous-échantillonnage

(3) La partie décodeur de droite est un réseau de fusion de caractéristiques : la carte de caractéristiques générée par le suréchantillonnage de droite et la carte de caractéristiques du sous-échantillonnage de gauche sont concaténées dans la dimension du canal. (Les dimensions de l'image sont : BCHW, respectivement batchsize, channel, height, width)

Le but du suréchantillonnage : la couche de regroupement divise par deux la largeur et la hauteur de l'image, ce qui perdra des informations sur l'image et réduira la résolution. Le suréchantillonnage peut améliorer la résolution de l'image et conserver les caractéristiques abstraites de haut niveau, puis assembler avec l'image haute résolution des caractéristiques de surface de bas niveau sur la gauche.

Méthode de suréchantillonnage : L'utilisation de la convolution transposée nn.ConvTranspose2d() au lieu d'une simple méthode de suréchantillonnage par interpolation peut obtenir le même effet et approfondir le réseau.

(4) Enfin, après deux opérations de convolution 3*3, puis en utilisant un noyau de convolution 1*1, la dimension du canal de sortie est le nombre de catégories num_classes à diviser, et la dimension générée est (B, num_classes, H, W ) présente une image.

code: 

import torch
import torch.nn as nn
import torch.nn.functional as F


def X2conv(in_channel,out_channel):
    """连续两个3*3卷积"""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU())

class DownsampleLayer(nn.Module):
    """
    下采样层
    """
    def __init__(self,in_channel,out_channel):
        super(DownsampleLayer, self).__init__()
        self.x2conv=X2conv(in_channel,out_channel)
        self.pool=nn.MaxPool2d(kernel_size=2,ceil_mode=True)

    def forward(self,x):
        """
        :param x:上一层pool后的特征
        :return: out_1转入右侧(待拼接),out_1输入到下一层,
        """
        out_1=self.x2conv(x)
        out=self.pool(out_1)
        return out_1,out

class UpSampleLayer(nn.Module):
    """
    上采样层
    """
    def __init__(self,in_channel,out_channel):

        super(UpSampleLayer, self).__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        self.upsample=nn.ConvTranspose2d\ 
(in_channels=out_channel,out_channels=out_channel//2,kernel_size=3,stride=2,padding=1)

    def forward(self,x,out):
        '''
        :param x: decoder中:输入层特征,经过x2conv与上采样upsample,然后拼接
        :param out:左侧encoder层中特征(与右侧上采样层进行cat)
        :return:
        '''
        x=self.x2conv(x)
        x=self.upsample(x)

        # x.shape中H W 应与 out.shape中的H W相同
        if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
            # 将右侧特征H W大小插值变为左侧特征H W大小
            x = F.interpolate(x, size=(out.size(2), out.size(3)),
                            mode="bilinear", align_corners=True)


        # Concatenate(在channel维度)
        cat_out = torch.cat([x, out], dim=1)
        return cat_out

class UNet(nn.Module):
    """
    UNet模型,num_classes为分割类别数
    """
    def __init__(self,num_classes):
        super(UNet, self).__init__()
        #下采样
        self.d1=DownsampleLayer(3,64) #3-64
        self.d2=DownsampleLayer(64,128)#64-128
        self.d3=DownsampleLayer(128,256)#128-256
        self.d4=DownsampleLayer(256,512)#256-512

        #上采样
        self.u1=UpSampleLayer(512,1024)#512-1024-512
        self.u2=UpSampleLayer(1024,512)#1024-512-256
        self.u3=UpSampleLayer(512,256)#512-256-128
        self.u4=UpSampleLayer(256,128)#256-128-64

        #输出:经过一个二层3*3卷积 + 1个1*1卷积
        self.x2conv=X2conv(128,64)
        self.final_conv=nn.Conv2d(64,num_classes,kernel_size=1)  # 最后一个卷积层的输出通道数为分割的类别数
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self,x):
        # 下采样层
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)

        # 上采样层 拼接
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)

        # 最后的三层卷积
        out=self.x2conv(out8)
        out=self.final_conv(out)
        return out

if __name__ == "__main__":
    img = torch.randn((2, 3, 360, 480))  # 正态分布初始化

    model = UNet(num_classes=16)

    output = model(img)
    print(output.shape)

Je suppose que tu aimes

Origine blog.csdn.net/m0_63077499/article/details/127089803
conseillé
Classement