Modelo de Segmentación de Imágenes-Unet y Código

Estructura de la red Unet:

(1) La red UNet es una estructura en forma de U. Uso de una red neuronal totalmente convolucional sin operaciones totalmente conectadas

(2) La parte del codificador de la izquierda es la red de extracción de características : use conv y pooling para reducir la resolución

(3) La parte del decodificador de la derecha es una red de fusión de funciones : el mapa de funciones generado mediante el muestreo ascendente en el lado derecho y el mapa de funciones muestreado descendentemente en el lado izquierdo se concatenan en la dimensión del canal. (Las dimensiones de la imagen son: BCHW, respectivamente tamaño de lote, canal, altura, ancho)

El propósito del muestreo superior : la capa de agrupación reduce a la mitad el ancho y el alto de la imagen, lo que perderá información de la imagen y reducirá la resolución. El sobremuestreo puede mejorar la resolución de la imagen y conservar las características abstractas de alto nivel, y luego unirlas con la imagen de alta resolución de las características de la superficie de bajo nivel a la izquierda.

Método de muestreo superior : el uso de convolución transpuesta nn.ConvTranspose2d() en lugar de un método de muestreo superior de interpolación simple puede lograr el mismo efecto y profundizar la red.

(4) Finalmente, después de dos operaciones de convolución 3*3, y luego usar un kernel de convolución 1*1, la dimensión del canal de salida es el número de categorías num_classes a dividir, y la dimensión generada es (B, num_classes, H, W ) cuenta con imagen.

código: 

import torch
import torch.nn as nn
import torch.nn.functional as F


def X2conv(in_channel,out_channel):
    """连续两个3*3卷积"""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU())

class DownsampleLayer(nn.Module):
    """
    下采样层
    """
    def __init__(self,in_channel,out_channel):
        super(DownsampleLayer, self).__init__()
        self.x2conv=X2conv(in_channel,out_channel)
        self.pool=nn.MaxPool2d(kernel_size=2,ceil_mode=True)

    def forward(self,x):
        """
        :param x:上一层pool后的特征
        :return: out_1转入右侧(待拼接),out_1输入到下一层,
        """
        out_1=self.x2conv(x)
        out=self.pool(out_1)
        return out_1,out

class UpSampleLayer(nn.Module):
    """
    上采样层
    """
    def __init__(self,in_channel,out_channel):

        super(UpSampleLayer, self).__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        self.upsample=nn.ConvTranspose2d\ 
(in_channels=out_channel,out_channels=out_channel//2,kernel_size=3,stride=2,padding=1)

    def forward(self,x,out):
        '''
        :param x: decoder中:输入层特征,经过x2conv与上采样upsample,然后拼接
        :param out:左侧encoder层中特征(与右侧上采样层进行cat)
        :return:
        '''
        x=self.x2conv(x)
        x=self.upsample(x)

        # x.shape中H W 应与 out.shape中的H W相同
        if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
            # 将右侧特征H W大小插值变为左侧特征H W大小
            x = F.interpolate(x, size=(out.size(2), out.size(3)),
                            mode="bilinear", align_corners=True)


        # Concatenate(在channel维度)
        cat_out = torch.cat([x, out], dim=1)
        return cat_out

class UNet(nn.Module):
    """
    UNet模型,num_classes为分割类别数
    """
    def __init__(self,num_classes):
        super(UNet, self).__init__()
        #下采样
        self.d1=DownsampleLayer(3,64) #3-64
        self.d2=DownsampleLayer(64,128)#64-128
        self.d3=DownsampleLayer(128,256)#128-256
        self.d4=DownsampleLayer(256,512)#256-512

        #上采样
        self.u1=UpSampleLayer(512,1024)#512-1024-512
        self.u2=UpSampleLayer(1024,512)#1024-512-256
        self.u3=UpSampleLayer(512,256)#512-256-128
        self.u4=UpSampleLayer(256,128)#256-128-64

        #输出:经过一个二层3*3卷积 + 1个1*1卷积
        self.x2conv=X2conv(128,64)
        self.final_conv=nn.Conv2d(64,num_classes,kernel_size=1)  # 最后一个卷积层的输出通道数为分割的类别数
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self,x):
        # 下采样层
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)

        # 上采样层 拼接
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)

        # 最后的三层卷积
        out=self.x2conv(out8)
        out=self.final_conv(out)
        return out

if __name__ == "__main__":
    img = torch.randn((2, 3, 360, 480))  # 正态分布初始化

    model = UNet(num_classes=16)

    output = model(img)
    print(output.shape)

Supongo que te gusta

Origin blog.csdn.net/m0_63077499/article/details/127089803
Recomendado
Clasificación