Intensive reading of deep learning papers [6]: UNet++


Since the encoder-decoder structure of UNet was proposed, it has come close to unifying deep learning image segmentation, and UNet-based improvement schemes have kept appearing ever since. Some researchers have also examined the effectiveness of UNet from the network structure itself: how many layers should the encoder-decoder network have, can the skip connections be designed differently, and what kind of structure trains more effectively? UNet itself was proposed for the task of medical image segmentation. Unlike natural image segmentation, which usually does not demand extremely strict accuracy, the segmentation of organs and lesions in medical images requires very high precision, because in many cases the segmentation result directly informs the corresponding clinical diagnosis. Motivated by these two goals, namely designing a better UNet structure and improving the accuracy of medical image segmentation, researchers proposed a nested UNet structure (Nested UNet), also called UNet++. The paper, UNet++: A Nested U-Net Architecture for Medical Image Segmentation, was published at the 2018 Medical Image Computing and Computer Assisted Intervention (MICCAI) conference.

UNet++ is called Nested UNet because its overall encoder-decoder structure nests encoder-decoder sub-networks inside it. On this basis, the skip connections in the middle of UNet are redesigned, and a deep supervision mechanism is added to accelerate training convergence. The complete UNet++ structure is shown in the figure below.

[Figure: the complete UNet++ architecture]

The black part of the figure is the original UNet structure, consisting of three parts: the downsampling encoder, the upsampling decoder, and the skip connections drawn as black dotted lines. The green part shows the nested UNet sub-networks, consisting of convolution and upsampling units. The blue dotted lines are the redesigned skip connections of UNet++, similar to the dense connections of DenseNet; their role here is to provide skip connections for the sub-networks. The red connections at the top form the deep supervision mechanism added by UNet++, whose purpose is to enable the network to train successfully.

Let's interpret UNet++ from the perspective of structural design. For the UNet structure, the most basic question is how many layers the network should have. The original UNet uses 4 downsampling layers and 4 upsampling layers, so are 4 layers enough for every segmentation task? The answer is no. From the network analyses earlier in this series, we already know that shallow layers extract coarse-grained features and capture the basic shapes in an image, while deep layers extract abstract features and capture its semantic information; both shallow and deep features have their benefits. Like the earlier RefineNet viewpoint, the authors of UNet++ believe that features at all levels, whether shallow, deep, or in between, matter for the final segmentation. Some segmentation tasks are simple, with relatively uniform image content, and a shallow network may already achieve good results; other tasks are complex, with rich image content, and may require a deeper structure. The original UNet design struggles to cover both cases at once. UNet++ achieves this generality by nesting UNet sub-networks of different depths, which settles the question of UNet's depth.

The second question is how to adjust the skip connections after adding nested networks of different depths. In UNet, a skip connection links each encoder layer directly to the corresponding decoder layer on the upsampling path. After adding the nested sub-networks, however, the original long connections in UNet no longer exist and are replaced by short connections within each sub-network. The authors of UNet++ argue that the long connections are necessary: they link early and late information in the network and compensate well for the information lost during downsampling. UNet++ therefore borrows the dense connection design of DenseNet and adds long connections to the nested network, as shown in the figure below.

[Figure: redesigned dense skip connections in UNet++]
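To make the dense connection pattern concrete: each nested node receives every earlier feature map of its own resolution level (the short connections, plus the long-range link back to the encoder output) together with the upsampled output of the node one level below. A minimal sketch of this node rule in PyTorch (the function and argument names are illustrative, not from the paper's code):

import torch

def unetpp_node(conv_block, same_row_feats, below_feat, up):
    # same_row_feats: [X(i,0), ..., X(i,j-1)] -- dense connections within level i
    # below_feat:     X(i+1, j-1)             -- node one resolution level down
    # conv_block / up: a convolution unit and a 2x upsampling operator
    fused = torch.cat(same_row_feats + [up(below_feat)], dim=1)  # channel-wise concat
    return conv_block(fused)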

But this raises a third problem: during backpropagation, the intermediate nodes may not receive gradients from the loss function. Here comes the trick: UNet++ uses deep supervision to force gradients through these nodes so the network can train normally. The benefits of deep supervision for UNet++ do not stop there: by attaching a supervised loss at each nested depth, UNet++ can be pruned at inference time, making the architecture scalable (see the sketch after the list below). In summary, UNet++ has the following two advantages over the original UNet:

(1) It fuses image features at different levels through nested sub-networks and long-short connections, making segmentation more accurate;

(2) Its flexible structure, combined with the deep supervision mechanism, allows a deep network with a huge number of parameters to be pruned heavily while keeping accuracy within an acceptable range.
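As a concrete illustration of point (2): with deep supervision, the loss is attached to all four side outputs X(0,1) through X(0,4), and at inference the network can be truncated at any of these heads. A minimal training-loss sketch, assuming the NestedUNet implementation below with deep_supervision=True; the plain averaged BCE here is a simplification of the paper's combined binary cross-entropy and Dice loss:

import torch.nn.functional as F

def deep_supervision_loss(outputs, target):
    # outputs: list of four side-output logits, each of shape (N, C, H, W)
    # target:  float mask of the same shape
    losses = [F.binary_cross_entropy_with_logits(o, target) for o in outputs]
    return sum(losses) / len(losses)

# Pruning at inference: a model trained this way can be cut back to the
# sub-network ending at X(0,i) by reading only outputs[i-1]; the deeper
# nodes then never need to be computed, which shrinks the parameter count.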

A comparison of segmentation results between UNet++ and UNet is shown in the figure below.

[Figure: segmentation results of UNet++ compared with UNet]

UNet++ has further expanded the UNet family of networks, and many later improvements build on it, such as Attention UNet++ and UNet 3+. The following code gives a reference implementation of UNet++; the complete code is available at:

https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py
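The snippet below needs the PyTorch imports and the VGGBlock convolution unit it references; in the linked repository, VGGBlock is a simple two-layer (3x3 conv + BatchNorm + ReLU) block, reproduced here so the code runs standalone:

import torch
import torch.nn as nn

class VGGBlock(nn.Module):
    # Two 3x3 conv + BatchNorm + ReLU layers, matching the linked repository
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        return out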

class NestedUNet(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [32, 64, 128, 256, 512]  # channel width at each resolution level
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone encoder nodes X(i,0): the plain downsampling path of the original UNet
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested decoder nodes X(i,j), j >= 1: each concatenates j same-level
        # feature maps (dense short connections) with one upsampled map from
        # the level below, hence nb_filter[i]*j + nb_filter[i+1] input channels
        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        # Output heads: one 1x1 conv per nested depth under deep supervision,
        # otherwise a single head on the deepest node X(0,4)
        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        # Each block below completes one more nested sub-UNet (one more column j)
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deep_supervision:
            # One side output per nested depth; all four are supervised during training
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output
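A quick shape check (hypothetical usage; the input height and width must be divisible by 16 to survive the four pooling steps):

model = NestedUNet(num_classes=1, input_channels=3, deep_supervision=True)
x = torch.randn(2, 3, 96, 96)
outputs = model(x)
print([o.shape for o in outputs])  # four tensors of shape [2, 1, 96, 96]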

Based on readers' feedback on this series, subsequent related content will be gradually added to the GitHub repository of the Deep Learning Semantic Segmentation Theory and Practice Guide:

https://github.com/luwill/Semantic-Segmentation-Guide

Past highlights:

 Intensive reading of deep learning papers [1]: FCN full convolutional network

 Intensive reading of deep learning papers [2]: UNet network

 Intensive reading of deep learning papers [3]: SegNet

 Intensive reading of deep learning papers [4]: RefineNet

 Intensive reading of deep learning papers [5]: Attention UNet



Source: blog.csdn.net/weixin_37737254/article/details/125923940