Image segmentation - U-Net implementation

Semantic segmentation is an important branch of image processing and computer vision. Unlike an image-level classification task, semantic segmentation must decide the category of every pixel in the image, producing a precise segmentation map. It is currently widely used in autonomous driving, automatic matting, and medical imaging. In short: semantic segmentation is a per-pixel classification problem!
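Concretely (a minimal, illustrative snippet, not from the original post): a segmentation network outputs a tensor of per-pixel class scores, and taking the argmax over the class dimension assigns one label to each pixel.

import torch

# Hypothetical example of per-pixel classification.
# logits: (N, C, H, W) class scores from a segmentation network; here C = 2 classes.
logits = torch.randn(1, 2, 256, 256)
# Argmax over the channel dimension picks one class per pixel.
pred_mask = logits.argmax(dim=1)  # (1, 256, 256), values in {0, 1}
print(pred_mask.shape)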

U-Net is arguably the most commonly used and simplest segmentation model. It is simple, efficient, easy to understand, easy to build, and can be trained from small datasets. Its main contribution is the U-shaped structure, which allows it to use fewer training images while still achieving good segmentation accuracy. The network structure of U-Net is described below.

The U-Net architecture is very simple: the first half performs feature extraction, and the second half performs upsampling. In some of the literature this is called an encoder-decoder structure (in an autoencoder the label is the input itself; in this encoder-decoder the label is the mask). Because the overall shape of the network resembles a capital letter U, it is called U-Net.

  • Encoder: the left half repeatedly applies a downsampling module consisting of two 3x3 convolutional layers (ReLU) followed by a 2x2 max-pooling layer (as can be seen in the code below);
  • Decoder: the right half repeatedly applies an upsampling convolutional layer (or transposed convolution), a feature concatenation (concat), and two 3x3 convolutional layers (ReLU) (as can be seen in the code below);

This structure first convolves and pools the image; in the U-Net paper this is done 4 times. For example, an input of 224x224 becomes feature maps of 112x112, 56x56, 28x28, and 14x14, i.e. features at four different scales. The 14x14 feature map is then upsampled (or deconvolved) to 28x28 and concatenated along the channel dimension with the earlier 28x28 feature map; the concatenated map is convolved and upsampled to obtain a 56x56 feature map, which is again concatenated with the earlier 56x56 features, convolved, and upsampled. After four rounds of upsampling, the network produces a 224x224 prediction of the same size as the input image. One such decoder step is sketched below.
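A minimal sketch of one decoder step under the 224x224 example above (the 1x1 convolution and channel counts here are illustrative assumptions; the full model later in this post uses a 3x3 convolution to halve the channels):

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

# Illustrative tensors for one decoder step (batch size 1; channel counts assumed).
deep = torch.randn(1, 1024, 14, 14)   # deepest feature map
skip = torch.randn(1, 512, 28, 28)    # matching encoder feature map

up = interpolate(deep, scale_factor=2, mode="nearest")  # (1, 1024, 28, 28)
halve = nn.Conv2d(1024, 512, kernel_size=1)             # halve the channels (1x1 conv, for illustration)
fused = torch.cat((halve(up), skip), dim=1)             # (1, 1024, 28, 28)
print(fused.shape)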

Compared with the FCN network proposed earlier, U-Net uses concatenation as its feature-map fusion method (the two fusion styles are contrasted in the sketch after this list):

  • FCN fuses features by element-wise addition of the corresponding pixel values of the feature maps;
  • U-Net concatenates along the channel dimension, forming thicker features; of course, this consumes more GPU memory;
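A minimal sketch contrasting the two fusion styles (shapes are illustrative):

import torch

a = torch.randn(1, 256, 56, 56)
b = torch.randn(1, 256, 56, 56)

# FCN-style fusion: element-wise addition, channel count unchanged.
fcn_fused = a + b                      # (1, 256, 56, 56)

# U-Net-style fusion: channel concatenation, "thicker" features.
unet_fused = torch.cat((a, b), dim=1)  # (1, 512, 56, 56)

print(fcn_fused.shape, unet_fused.shape)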

The benefits of U-Net: the deeper the layer, the larger the receptive field of its feature maps. Shallow convolutions focus on texture features, while deep layers capture more abstract, semantic features, so both deep and shallow features are meaningful. Another point is that the larger feature maps obtained through deconvolution lack edge information: every downsampling step inevitably loses some edge features, and upsampling cannot recover what was lost. By concatenating encoder features into the decoder, U-Net recovers this edge information.

Code example:

import torch
import torch.nn as nn
from torch.nn.functional import interpolate


# U-Net
# DoubleConv: the basic U-Net block of two 3x3 convolutions
class CNNlayer(nn.Module):
    def __init__(self, c_in, c_Out):
        super(CNNlayer, self).__init__()
        self.layer = nn.Sequential(
            # padding_mode="reflect": pad 1 unit on each side of the height and width dimensions
            nn.Conv2d(c_in, c_Out, 3, 1, padding=1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(c_Out),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),

            nn.Conv2d(c_Out, c_Out, 3, 1, 1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(c_Out),
            nn.LeakyReLU(),
            nn.Dropout2d(0.4)
        )

    def forward(self, x):
        return self.layer(x)

# Downsampling (using max pooling) -- relatively strong noise suppression
class DownSampling(nn.Module):
    def __init__(self):
        super(DownSampling, self).__init__()
        self.layer = nn.Sequential(
            nn.MaxPool2d(2)
        )

    def forward(self, x):
        return self.layer(x)
# Alternative 2: downsample with a stride-2 convolution
# class DownSampling(nn.Module):
#     def __init__(self,C):
#         super(DownSampling, self).__init__()
#         self.layer=nn.Sequential(
#             nn.Conv2d(C,C,3,2,1,padding_mode="reflect"),
#             nn.LeakyReLU(),
#             nn.BatchNorm2d(C)
#         )
#     def forward(self,x):
#         return self.layer(x)

# Upsampling + multi-scale feature-map fusion (concat)
class UpSampling(nn.Module):
    def __init__(self, c):
        super(UpSampling, self).__init__()
        # Double the spatial size of the feature map and halve the channel count
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=c // 2, kernel_size=3, stride=1, padding=1, padding_mode="reflect", bias=False),
            nn.BatchNorm2d(c // 2),
            nn.LeakyReLU(),
        )
    def forward(self, x, r):
        # Upsample with nearest-neighbor interpolation
        up = interpolate(x, scale_factor=2, mode="nearest")  # spatial size x2, channels unchanged
        x = self.layer(up)  # channels halved, size unchanged
        # Channel concatenation (cat) with the encoder feature map
        out = torch.cat((x, r), dim=1)  # channels merged, size unchanged
        return out


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # 4 downsampling stages (shape comments assume a 1x3x256x256 input)
        self.C1 = CNNlayer(3, 64)     # 1, 64, 256, 256
        self.D1 = DownSampling()      # 1, 64, 128, 128

        self.C2 = CNNlayer(64, 128)   # 1, 128, 128, 128
        self.D2 = DownSampling()      # 1, 128, 64, 64

        self.C3 = CNNlayer(128, 256)  # 1, 256, 64, 64
        self.D3 = DownSampling()      # 1, 256, 32, 32

        self.C4 = CNNlayer(256, 512)  # 1, 512, 32, 32
        self.D4 = DownSampling()      # 1, 512, 16, 16

        # Middle (bottleneck)
        self.C5_ground = CNNlayer(512, 1024)  # 1, 1024, 16, 16

        # 4 upsampling + concat stages
        self.U1 = UpSampling(1024)     # 1, 1024, 32, 32
        self.C6 = CNNlayer(1024, 512)  # 1, 512, 32, 32

        self.U2 = UpSampling(512)      # 1, 512, 64, 64
        self.C7 = CNNlayer(512, 256)   # 1, 256, 64, 64

        self.U3 = UpSampling(256)      # 1, 256, 128, 128
        self.C8 = CNNlayer(256, 128)   # 1, 128, 128, 128

        self.U4 = UpSampling(128)      # 1, 128, 256, 256
        self.C9 = CNNlayer(128, 64)    # 1, 64, 256, 256

        # Output layer 64 -> 2 (a binary problem: background + foreground)
        self.Pre = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        # Downsampling path
        R1 = self.C1(x)
        R2 = self.C2(self.D1(R1))
        R3 = self.C3(self.D2(R2))
        R4 = self.C4(self.D3(R3))
        R5 = self.C5_ground(self.D4(R4))

        # Upsampling path (with skip connections)
        O1 = self.C6(self.U1(R5, R4))
        O2 = self.C7(self.U2(O1, R3))
        O3 = self.C8(self.U3(O2, R2))
        O4 = self.C9(self.U4(O3, R1))

        return self.Pre(O4)


if __name__ == '__main__':
    # Quick test of the network structure
    x = torch.randn(1, 3, 512, 512)
    net = UNet()
    out = net(x)
    print(out.shape)
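Continuing from the definitions above, a minimal sketch of one training step (the Adam optimizer, learning rate, and random target mask here are assumptions for illustration): since the network outputs 2 channels of per-pixel logits, nn.CrossEntropyLoss can be applied directly with an integer mask.

net = UNet()
x = torch.randn(1, 3, 512, 512)
out = net(x)  # (1, 2, 512, 512) per-pixel logits

# CrossEntropyLoss expects logits (N, 2, H, W) and an integer mask (N, H, W).
target = torch.randint(0, 2, (1, 512, 512))  # synthetic ground-truth mask, for illustration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())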
