Image Segmentation Model: UNet and Code

UNet network structure:

(1) UNet has a U-shaped structure. It is a fully convolutional network with no fully connected layers.

(2) The encoder on the left is the feature extraction network: it uses convolution and pooling for downsampling.

(3) The decoder on the right is the feature fusion network: the feature map produced by upsampling on the right is concatenated along the channel dimension with the encoder feature map of the same resolution from the left. (Tensors are laid out as BCHW: batch size, channels, height, width.)

The purpose of upsampling: each pooling layer halves the width and height of the feature map, which discards spatial detail and reduces resolution. Upsampling restores the resolution while keeping the high-level abstract features, which are then concatenated with the high-resolution, low-level features from the left side.

Upsampling method: this code uses a transposed convolution, nn.ConvTranspose2d(), instead of plain interpolation. It enlarges the feature map in the same way, but its weights are learned, so it also adds trainable layers to the network.
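
The size arithmetic behind the two paragraphs above can be checked with a small shape experiment. This is only a sketch: the 64-channel tensor and the 45x60 size are arbitrary values chosen for illustration, and the kernel_size=2, stride=2 layer is shown as a common alternative, not as something the code in this post uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 64, 45, 60)  # illustrative BCHW feature map

# Max pooling halves H and W; ceil_mode=True rounds an odd size up (45 -> 23).
pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)
pooled = pool(x)
print(pooled.shape)  # torch.Size([1, 64, 23, 30])

# Transposed convolution output size (dilation 1, output_padding 0):
#   H_out = (H_in - 1) * stride - 2 * padding + kernel_size
up_k3 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1)
up_k2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
y_k3 = up_k3(pooled)  # 2 * H_in - 1
y_k2 = up_k2(pooled)  # 2 * H_in
print(y_k3.shape)  # torch.Size([1, 64, 45, 59])
print(y_k2.shape)  # torch.Size([1, 64, 46, 60])

# Parameter-free alternative: bilinear interpolation straight to a target size.
y_interp = F.interpolate(pooled, size=x.shape[2:], mode="bilinear", align_corners=True)
print(y_interp.shape)  # torch.Size([1, 64, 45, 60])
```

Neither transposed-convolution variant returns exactly 45x60 here, which is why a decoder may need a resize step before its output can be concatenated with the encoder feature map.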

(4) Finally, after two 3*3 convolutions, a 1*1 convolution maps the channels to num_classes, the number of segmentation classes, producing a feature map of shape (B, num_classes, H, W).

Code:
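
Steps (3) and (4) can be illustrated with shapes alone. The sketch below uses made-up tensors (decoder_feat and encoder_feat are hypothetical names) and an assumed num_classes of 16. It leaves out the two 3*3 convolutions and shows only the channel concatenation and the 1*1 classification layer.

```python
import torch
import torch.nn as nn

num_classes = 16  # assumed value for illustration
decoder_feat = torch.randn(2, 64, 90, 120)  # hypothetical upsampled decoder feature
encoder_feat = torch.randn(2, 64, 90, 120)  # hypothetical encoder feature, same H and W

# Skip connection: concatenate along dim=1 (C in BCHW), so the channel counts add up.
fused = torch.cat([decoder_feat, encoder_feat], dim=1)
print(fused.shape)  # torch.Size([2, 128, 90, 120])

# A 1x1 convolution changes only the channel count, giving one score map per class.
head = nn.Conv2d(128, num_classes, kernel_size=1)
logits = head(fused)
print(logits.shape)  # torch.Size([2, 16, 90, 120])
```

Concatenation requires every dimension except the channel dimension to match, which is why H and W of the two feature maps must agree first.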

import torch
import torch.nn as nn
import torch.nn.functional as F


def X2conv(in_channel, out_channel):
    """Two consecutive 3*3 convolutions, each followed by BatchNorm and ReLU"""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU())

class DownsampleLayer(nn.Module):
    """
    Downsampling (encoder) layer
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        """
        :param x: features from the previous layer after pooling
        :return: out_1 goes to the decoder on the right (to be concatenated); out is the input to the next layer
        """
        out_1 = self.x2conv(x)
        out = self.pool(out_1)
        return out_1, out

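class UpSampleLayer(nn.Module):
    """
    Upsampling (decoder) layer
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        # The transposed convolution halves the channels and roughly doubles H and W.
        # With kernel_size=3, stride=2, padding=1 the output size is 2 * H_in - 1.
        self.upsample = nn.ConvTranspose2d(
            in_channels=out_channel, out_channels=out_channel // 2,
            kernel_size=3, stride=2, padding=1)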

    def forward(self, x, out):
        """
        :param x: decoder input features; passed through x2conv and upsample, then concatenated
        :param out: encoder features from the left side (concatenated with the upsampled features)
        :return: the concatenated feature map
        """
        x = self.x2conv(x)
        x = self.upsample(x)

        # H and W of x must match H and W of out
        if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
            # Resize the decoder features to the encoder features' H and W by interpolation
            x = F.interpolate(x, size=(out.size(2), out.size(3)),
                              mode="bilinear", align_corners=True)

        # Concatenate along the channel dimension
        cat_out = torch.cat([x, out], dim=1)
        return cat_out

class UNet(nn.Module):
    """
    UNet model; num_classes is the number of segmentation classes
    """
    def __init__(self, num_classes):
        super().__init__()
        # Downsampling (encoder)
        self.d1 = DownsampleLayer(3, 64)      # 3-64
        self.d2 = DownsampleLayer(64, 128)    # 64-128
        self.d3 = DownsampleLayer(128, 256)   # 128-256
        self.d4 = DownsampleLayer(256, 512)   # 256-512

        # Upsampling (decoder); comments give channels as input - after x2conv - after upsample
        self.u1 = UpSampleLayer(512, 1024)    # 512-1024-512
        self.u2 = UpSampleLayer(1024, 512)    # 1024-512-256
        self.u3 = UpSampleLayer(512, 256)     # 512-256-128
        self.u4 = UpSampleLayer(256, 128)     # 256-128-64

        # Output: two 3*3 convolutions + one 1*1 convolution
        self.x2conv = X2conv(128, 64)
        # The final convolution outputs one channel per segmentation class
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

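    def forward(self, x):
        # Downsampling layers: out_N is kept for the skip connection, outN feeds the next layer
        out_1, out1 = self.d1(x)
        out_2, out2 = self.d2(out1)
        out_3, out3 = self.d3(out2)
        out_4, out4 = self.d4(out3)

        # Upsampling layers with concatenation
        out5 = self.u1(out4, out_4)
        out6 = self.u2(out5, out_3)
        out7 = self.u3(out6, out_2)
        out8 = self.u4(out7, out_1)

        # The final three convolution layers
        out = self.x2conv(out8)
        out = self.final_conv(out)
        return out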

if __name__ == "__main__":
    img = torch.randn((2, 3, 360, 480))  # random input drawn from a standard normal distribution

    model = UNet(num_classes=16)

    output = model(img)
    print(output.shape)  # torch.Size([2, 16, 360, 480])
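
The (B, num_classes, H, W) output holds raw per-class scores, not a label map. Below is a minimal sketch of how such an output is commonly consumed. Random tensors stand in for the model output and the ground-truth mask, and the use of nn.CrossEntropyLoss is an assumption about the training setup, which this post does not describe.

```python
import torch
import torch.nn as nn

num_classes = 16
logits = torch.randn(2, num_classes, 360, 480)         # stand-in for model(img)
target = torch.randint(0, num_classes, (2, 360, 480))  # stand-in ground-truth class indices

# Per-pixel prediction: take the highest-scoring class along the channel dimension.
pred = logits.argmax(dim=1)
print(pred.shape)  # torch.Size([2, 360, 480])

# nn.CrossEntropyLoss accepts (B, C, H, W) logits with (B, H, W) integer targets.
loss = nn.CrossEntropyLoss()(logits, target)
print(loss.item())
```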

Origin blog.csdn.net/m0_63077499/article/details/127089803