Pytorch Deep Learning Practical Tutorial (2): UNet Semantic Segmentation Network


This article has been included on GitHub: https://github.com/Jack-Cherish/PythonPark. There you will find practical technical articles, organized learning materials, and interview experience from first-tier companies. Stars and contributions are welcome.

I. Introduction

This article belongs to the Pytorch deep learning semantic segmentation tutorial series.

This series of articles covers:

  • Basic use of Pytorch
  • Explanation of semantic segmentation algorithm

If you are not familiar with the principles of semantic segmentation or with setting up the development environment, please see the previous article in this tutorial series: "Pytorch Deep Learning Practical Tutorial (1): Semantic Segmentation Basics and Environment Construction".

This article uses the Windows development environment built in the previous article:

Development environment: Windows

Development language: Python3.7.4

Framework version: Pytorch1.3.0

CUDA: 10.2

cuDNN: 7.6.0

This article mainly explains the UNet network structure and how to implement it in code.

PS: All the code in this article can be downloaded from my GitHub. Follows and Stars are welcome: click to view

II. UNet Network Structure

In the field of semantic segmentation, the pioneering deep-learning-based work is FCN (Fully Convolutional Networks for Semantic Segmentation). UNet follows the principles of FCN with corresponding improvements, adapting it to simple segmentation problems with small training sets.

UNet paper address: click to view

When studying a deep learning algorithm, a good approach is to look at the network structure first, and only after understanding it move on to the loss calculation, training method, and so on. This article mainly explains the network structure of UNet; the other topics will be explained in subsequent articles.

1. Network structure principle

UNet was first presented at the MICCAI conference in 2015, and in just over four years it has been cited more than 9,700 times.

UNet has become the baseline for most medical image semantic segmentation tasks, and it has also inspired a large number of researchers to study U-shaped network structures, producing a batch of papers that improve on the UNet architecture.

The two main features of the UNet network structure are the U-shaped structure and the Skip Connection.

UNet has a symmetrical network structure, with downsampling on the left and upsampling on the right.

Functionally, the series of downsampling operations on the left can be called the encoder, and the series of upsampling operations on the right the decoder.

The four gray parallel lines in the middle of the structure are the Skip Connections. A Skip Connection fuses a feature map from the downsampling process into the upsampling process.

The fusion operation used by the Skip Connection is also very simple: the channels of the two feature maps are stacked together, an operation commonly known as Concat.

The Concat operation is easy to understand with an analogy. Take a book A measuring 10cm*10cm with a thickness of 3cm, and a book B measuring 10cm*10cm with a thickness of 4cm.

Stack book A and book B with their edges aligned. The result is a stack of books measuring 10cm*10cm with a thickness of 7cm.

This kind of "stacking together" operation is Concat.

In the same way for feature maps: Concat-fusing a feature map of size 256*256*64 (that is, w (width) 256, h (height) 256, and c (number of channels) 64) with a feature map of size 256*256*32 yields a feature map of size 256*256*96.
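As a quick check of the shape arithmetic, here is a minimal snippet (the sizes are illustrative; note that PyTorch uses the NCHW layout, so the channel dimension comes before height and width):

import torch

a = torch.randn(1, 64, 256, 256)  # a 256*256*64 feature map, NCHW layout
b = torch.randn(1, 32, 256, 256)  # a 256*256*32 feature map

# Concat stacks the two maps along the channel dimension (dim=1)
fused = torch.cat([a, b], dim=1)
print(fused.shape)  # torch.Size([1, 96, 256, 256]), i.e. 256*256*96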

In actual use, the spatial sizes of the two feature maps to be Concat-fused are not necessarily the same, for example a 256*256*64 feature map and a 240*240*32 feature map.

In this case, there are two ways:

The first approach: crop the larger 256*256*64 feature map down to 240*240*64, for example by discarding 8 pixels on each of the top, bottom, left, and right sides, and then Concat to obtain a 240*240*96 feature map.

The second approach: pad the smaller 240*240*32 feature map up to 256*256*32, for example by adding 8 pixels on each of the top, bottom, left, and right sides, and then Concat to obtain a 256*256*96 feature map.

The Concat scheme adopted by UNet here is the second one: the smaller feature map is padded, using ordinary constant padding with zeros.
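A minimal sketch of this second approach, using the illustrative sizes from above (F.pad takes the padding amounts as left, right, top, bottom for the last two dimensions, and fills with 0 by default):

import torch
import torch.nn.functional as F

small = torch.randn(1, 32, 240, 240)  # 240*240*32
large = torch.randn(1, 64, 256, 256)  # 256*256*64

# zero-pad 8 pixels on each side: 240 + 8 + 8 = 256
small = F.pad(small, [8, 8, 8, 8])

fused = torch.cat([large, small], dim=1)
print(fused.shape)  # torch.Size([1, 96, 256, 256])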

2. Code

Some readers may not be very familiar with PyTorch, so here is a recommended official quick-start tutorial: in about an hour, you can pick up the basic concepts and how PyTorch code is written.

PyTorch official beginner tutorial: click to view

We will split the entire UNet network into multiple modules for explanation.

DoubleConv module:

Let's take a look at two consecutive convolution operations.

As can be seen from the UNet structure, every layer performs two consecutive convolution operations, in the downsampling path as well as the upsampling path. Since this operation is repeated many times throughout UNet, it is worth writing a separate DoubleConv module:

import torch.nn as nn

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # two consecutive 3x3 convolutions, each followed by batch
        # normalization and ReLU; padding=0 follows the original paper,
        # so each convolution shrinks height and width by 2 pixels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

A note on the PyTorch code above: torch.nn.Sequential is an ordered container; modules are added to it in the order they are passed in and executed in that same order. For the code above, the execution order is: convolution -> BN -> ReLU -> convolution -> BN -> ReLU.

The in_channels and out_channels of the DoubleConv module can be flexibly set for extended use.

In the network shown in the UNet paper, in_channels is set to 1 and out_channels is set to 64.

The input image size is 572*572. After one 3*3 convolution with a stride of 1 and padding of 0, a 570*570 feature map is obtained; after another such convolution, a 568*568 feature map is obtained.

Calculation formula: O=(H−F+2×P)/S+1

H is the size of the input feature map, O is the size of the output feature map, F is the size of the convolution kernel, P is the padding, and S is the stride.
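To make the formula concrete, here is a small, hypothetical helper that reproduces the 572 -> 570 -> 568 numbers, together with a shape check of the DoubleConv module defined above (assumed to be in scope):

import torch

def conv_out_size(h, f=3, p=0, s=1):
    """O = (H - F + 2*P) / S + 1"""
    return (h - f + 2 * p) // s + 1

print(conv_out_size(572))                 # 570
print(conv_out_size(conv_out_size(572)))  # 568

# the same result from the module itself: 1 input channel, 64 output channels
block = DoubleConv(1, 64)
x = torch.randn(1, 1, 572, 572)
print(block(x).shape)  # torch.Size([1, 64, 568, 568])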

Down module:

The UNet network has a total of 4 downsampling processes, and the modular code is as follows:

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 2x2 max pooling halves the spatial size, then DoubleConv
        # extracts features (and usually doubles the channel count)
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

The code here is very simple: a max-pooling layer that performs the downsampling, followed by a DoubleConv module.
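A quick shape check (illustrative sizes, assuming the Down and DoubleConv classes above are in scope): MaxPool2d(2) halves the spatial size, then each of the two padding-free 3*3 convolutions removes 2 pixels:

import torch

down = Down(64, 128)
x = torch.randn(1, 64, 568, 568)
print(down(x).shape)  # torch.Size([1, 128, 280, 280]): 568/2 = 284, 284 - 2 - 2 = 280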

At this point, the code for the downsampling path on the left half of the UNet network is complete; next comes the upsampling path on the right half.

Up module:

Naturally, the upsampling path is dominated by upsampling operations, but in addition to the conventional upsampling, it also performs feature fusion.

This piece of code is also slightly more complicated to implement:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # two upsampling options: bilinear interpolation (no learnable
        # parameters) or transposed convolution (which halves the channels);
        # the model in this article uses bilinear=False, i.e. ConvTranspose2d
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder feature map to be upsampled
        # x2: encoder feature map fused in via the Skip Connection
        x1 = self.up(x1)
        # input is CHW; compute the spatial size difference of the two maps
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # zero-pad x1 so that its spatial size matches x2
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        # concatenate along the channel dimension (dim=1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

This code is more complex, so let's look at it in two parts. First, the __init__ function defines the upsampling method and the DoubleConv convolution. Two upsampling methods are provided, Upsample and ConvTranspose2d, i.e., bilinear interpolation and deconvolution (transposed convolution).

Bilinear interpolation is easy to understand. In the standard schematic, the values at the four points Q11, Q12, Q21, and Q22 are known: R1 is interpolated from Q11 and Q21, R2 is interpolated from Q12 and Q22, and finally P is interpolated from R1 and R2. This process is bilinear interpolation.

For a feature map, bilinear upsampling essentially inserts new pixels between the existing ones; the value of each inserted pixel is determined by the values of the neighboring pixels.

Deconvolution (transposed convolution), as the name suggests, works in the reverse direction of convolution: where convolution makes the feature map smaller, deconvolution makes it larger. In the common schematic, the blue squares at the bottom are the input feature map, the white dashed squares around it are the padding (usually zeros), and the green squares on top are the output after convolution.

Such a schematic typically shows the process of going from a 2*2 feature map to a 4*4 feature map.
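A minimal sketch comparing the two upsampling methods on a toy tensor (sizes are illustrative): both turn a 2*2 feature map into a 4*4 one, but ConvTranspose2d has learnable weights and can also change the channel count:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 2, 2)

# bilinear interpolation: no learnable parameters, channel count unchanged
up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
print(up1(x).shape)  # torch.Size([1, 64, 4, 4])

# transposed convolution: learnable, here also halving the channels
up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
print(up2(x).shape)  # torch.Size([1, 32, 4, 4])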

In the forward function, x1 receives the feature map from the decoder path (the one to be upsampled), and x2 receives the corresponding encoder feature map for fusion. The fusion method is exactly as described above: pad the smaller feature map first, then Concat.
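For example, a hypothetical call to the Up module above with deliberately mismatched sizes (bilinear=False, so ConvTranspose2d first halves the 1024 channels to 512):

import torch

up = Up(1024, 512, bilinear=False)
x1 = torch.randn(1, 1024, 28, 28)  # decoder feature map, to be upsampled
x2 = torch.randn(1, 512, 57, 57)   # encoder feature map (skip connection)

# x1 is upsampled to 56*56, zero-padded to 57*57, concatenated with x2
# (512 + 512 = 1024 channels), then reduced back to 512 by DoubleConv
print(up(x1, x2).shape)  # torch.Size([1, 512, 53, 53])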

OutConv module:

With the DoubleConv, Down, and Up modules above, the main body of the UNet network can be assembled. What remains is the output layer of the network, which maps the feature channels to the required number of segmentation classes.

The operation is very simple: it is just a transformation of the number of channels. In the original UNet paper there are 2 classes, so the output has 2 channels.

Although this operation is very simple, it is still encapsulated as its own module, for the sake of tidiness and consistency.

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        # 1x1 convolution: maps the feature channels to the class channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
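For example (illustrative sizes), mapping 64 feature channels to 2 classes leaves the spatial size untouched, since a 1*1 convolution only mixes channels:

import torch

outc = OutConv(64, 2)
x = torch.randn(1, 64, 388, 388)
print(outc(x).shape)  # torch.Size([1, 2, 388, 388])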

At this point, all the modules used in the UNet network have been written. We can put the module code above into a file named unet_parts.py, then create unet_model.py and, following the UNet network structure, set the input and output channel counts of each module and the calling order:

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""

import torch.nn.functional as F

from unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
    
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

Run the command python unet_model.py; if there are no errors, you will get the following output:

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down2): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down3): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down4): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (outc): OutConv(
    (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)
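Beyond printing the module tree, a quick sanity check is to push a dummy tensor through the network. A minimal sketch, using the paper's 572*572 input size; note that because this implementation pads the upsampled feature maps instead of cropping the skip connections, the output is 564*564 rather than the paper's 388*388:

import torch

net = UNet(n_channels=3, n_classes=1)
x = torch.randn(1, 3, 572, 572)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 1, 564, 564])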

After the network is set up, the next step is to use the network for training. The specific implementation will be explained in the next article of this series of tutorials.

III. Summary

  • This article mainly explains the UNet network structure and organizes the UNet network into modules.
  • The next article explains how to write the training code for the UNet network.

Like it before you read on, and make it a habit. Search for the WeChat official account 【JackCui-AI】 to follow a programmer making his way on the Internet.
