(PyTorch Advanced Road) U-Net Image Segmentation

Overview

U-Net was originally developed for biomedical image segmentation. An electron-microscope image of cells is fed into U-Net, which outputs a segmentation map of the cell tissue.

The authors propose a U-shaped architecture for image segmentation. An image is fed into the network, and the network outputs a class for every pixel, for example whether that pixel belongs to the target object or to the background. Each class can then be rendered in a different color.

Overall model:
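The input is a single-channel 572×572 image and the output is a 2-channel 388×388 map. There are 2 output channels because each pixel is classified into one of two classes. The output is smaller than the input because every 3×3 convolution is unpadded and trims the border. To predict a 388×388 region, the network therefore needs a 572×572 input tile. Where that tile extends past the image border, the missing pixels are filled by mirroring the image, so that pixels near the border also have context information.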
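The border mirroring can be sketched with PyTorch's reflect padding. This minimal sketch assumes a 388×388 single-channel image whose full output map we want. `F.pad` with `mode="reflect"` stands in for the paper's mirroring, and the 92-pixel margin is simply (572 − 388) / 2.

```python
import torch
import torch.nn.functional as F

# A toy 388x388 single-channel image, shaped (batch, channels, H, W)
img = torch.arange(388 * 388, dtype=torch.float32).reshape(1, 1, 388, 388)

# Each side needs (572 - 388) / 2 = 92 extra pixels of context
margin = (572 - 388) // 2

# Mirror the border so edge pixels also get surrounding context.
# Pad order for the last two dims is (left, right, top, bottom).
padded = F.pad(img, (margin, margin, margin, margin), mode="reflect")

print(padded.shape)  # torch.Size([1, 1, 572, 572])
```

The original image sits unchanged in the center, at `padded[..., 92:480, 92:480]`.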

The first stage of the network
First, the single-channel 572×572 input passes through an unpadded 3×3 convolution with 64 output channels, giving 570×570×64.

The result passes through a second 3×3 convolution, giving 568×568×64.
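Each unpadded ("valid") 3×3 convolution trims one pixel from every border, so the side length drops by 2. The standard output-size formula out = ⌊(n + 2p − k) / s⌋ + 1 makes this explicit. The helper name `conv_out` is ours and does not come from the paper or the repository. The same formula also covers 2×2 max pooling with stride 2 when sizes are rounded down.

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    """Spatial output size of a convolution or pooling layer (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

s = 572
s = conv_out(s)  # first 3x3 valid conv
print(s)         # 570
s = conv_out(s)  # second 3x3 valid conv
print(s)         # 568
```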

In the second stage, the spatial size is halved and the number of channels is doubled.
First, a 2×2 max-pooling layer (stride 2) reduces 568×568 to 284×284, with the channel count unchanged at 64. Then, as in the first stage, two 3×3 convolutions follow. They double the channels to 128 and give 280×280×128.

The third, fourth, and fifth stages have the same structure as the second. Each halves the spatial size and doubles the number of channels.

The fifth stage is the bottom of the U. After pooling, the size is 32×32. Two 3×3 convolutions then reduce it to 28×28, and the number of channels is 1024.
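The contracting path can be traced with a few lines of plain Python. This sketch assumes the paper's configuration: valid 3×3 convolutions, 2×2 pooling with stride 2, and channel widths from 64 to 1024. The function names are illustrative.

```python
def valid_conv(size):
    # One unpadded 3x3 convolution removes 1 pixel from each border
    return size - 2

def encoder_shapes(size=572, channels=(64, 128, 256, 512, 1024)):
    """Return (side length, channels) at the end of each encoder stage."""
    shapes = []
    for i, c in enumerate(channels):
        if i > 0:
            size //= 2  # 2x2 max pooling with stride 2
        size = valid_conv(valid_conv(size))  # two 3x3 convs
        shapes.append((size, c))
    return shapes

print(encoder_shapes())
# [(568, 64), (280, 128), (136, 256), (64, 512), (28, 1024)]
```

All four pooling inputs (568, 280, 136, 64) are even, so no pixels are lost to rounding. The paper chooses the tile size with this requirement in mind.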

The first stage of the decoder:
The right side of the U reverses the process: the spatial size grows step by step and the number of channels shrinks. A 2×2 up-convolution (transposed convolution) doubles the size, e.g. from 28×28 to 56×56, and halves the channels from 1024 to 512. The corresponding 512-channel feature map from the encoder (64×64) is then copied, center-cropped to 56×56, and concatenated with the upsampled map. This step is called skip concatenation (the paper labels it "copy and crop"), and it helps recover fine spatial detail. The concatenated map has 1024 channels at 56×56, and two 3×3 convolutions turn it into 52×52×512.
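The copy-and-crop step can be sketched with PyTorch tensors. The sketch assumes the paper's shapes at the first decoder stage: a 64×64×512 encoder map and a 28×28×1024 bottom map. `center_crop` is a hypothetical helper written for this example. The weights are random because only the shapes matter here.

```python
import torch
import torch.nn as nn

def center_crop(feat, target_h, target_w):
    """Crop the central target_h x target_w window of an (N, C, H, W) tensor."""
    _, _, h, w = feat.shape
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return feat[:, :, top:top + target_h, left:left + target_w]

with torch.no_grad():
    bottom = torch.randn(1, 1024, 28, 28)  # bottom of the U
    skip = torch.randn(1, 512, 64, 64)     # 4th-stage encoder output

    # 2x2 up-convolution: doubles H and W, halves the channels
    up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)(bottom)  # (1, 512, 56, 56)

    # Crop the skip features to 56x56, then concatenate along channels
    merged = torch.cat([center_crop(skip, up.size(2), up.size(3)), up], dim=1)

print(merged.shape)  # torch.Size([1, 1024, 56, 56])
```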

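The second, third, and fourth decoder stages have the same structure as the first. Each stage upsamples, concatenates the cropped encoder features, and applies two 3×3 convolutions, so the spatial size keeps growing. The up-convolution halves the channels, concatenation restores them, and the two convolutions halve them again. As a result, each stage ends with half the channels of the previous one (512 → 256 → 128 → 64).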
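A plain-Python trace of the expanding path shows where 388 comes from and how much each skip map must be cropped. The encoder sizes used below (568, 280, 136, 64) are the paper's. The function itself is an illustrative sketch.

```python
def decoder_shapes(bottom=28, skips=(568, 280, 136, 64), widths=(512, 256, 128, 64)):
    """Trace (side length, channels, crop per border) through the expanding path."""
    size = bottom
    trace = []
    # Skips are consumed deepest-first: 64, 136, 280, 568
    for skip, c in zip(reversed(skips), widths):
        size *= 2                    # 2x2 up-conv doubles H and W
        crop = (skip - size) // 2    # pixels cropped from each border of the skip map
        size -= 4                    # two valid 3x3 convs after concatenation
        trace.append((size, c, crop))
    return trace

print(decoder_shapes())
# [(52, 512, 4), (100, 256, 16), (196, 128, 40), (388, 64, 88)]
```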

Finally, the map is 388×388×64. A 1×1 convolution (equivalent to a per-pixel MLP layer) then acts as the classification layer and produces 388×388×2.
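The "1×1 conv (MLP)" remark means the classifier applies the same linear map independently at every pixel. The sketch below checks this numerically by copying a 1×1 convolution's weights into an `nn.Linear`. It uses a small 8×8 map instead of 388×388, an assumption made only to keep it fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feat = torch.randn(1, 64, 8, 8)  # small stand-in for the 388x388x64 map

conv1x1 = nn.Conv2d(64, 2, kernel_size=1)
linear = nn.Linear(64, 2)

with torch.no_grad():
    # Give the linear layer the same weights: (2, 64, 1, 1) -> (2, 64)
    linear.weight.copy_(conv1x1.weight.view(2, 64))
    linear.bias.copy_(conv1x1.bias)

    out_conv = conv1x1(feat)                                           # (1, 2, 8, 8)
    out_linear = linear(feat.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # same layout

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```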

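U-Net features: it is a fully convolutional network with an encoder-decoder structure, similar in spirit to a seq2seq model. Because it has no fully connected layers, it can in principle process inputs of different sizes.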

Code

Repository address
https://github.com/yassouali/pytorch-segmentation
The models folder of this open-source project contains implementations of many image segmentation models.

The implementation is for reference only, and some parts of it are not written in a very standard way.

Full code

Part of the U-Net code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import chain

# Repository-local modules from pytorch-segmentation
from base import BaseModel
from utils.helpers import initialize_weights, set_trainable
from models import resnet


def x2conv(in_channels, out_channels, inner_channels=None):
    inner_channels = out_channels // 2 if inner_channels is None else inner_channels
    down_conv = nn.Sequential(
        nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(inner_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))
    return down_conv


class encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = x2conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x = self.pool(x)
        return x


class decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up_conv = x2conv(in_channels, out_channels)

    def forward(self, x_copy, x, interpolate=True):
        x = self.up(x)

        if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)):
            if interpolate:
                # Interpolating instead of padding
                x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
            else:
                # Padding in case the incoming volumes are of different sizes
                diffY = x_copy.size()[2] - x.size()[2]
                diffX = x_copy.size()[3] - x.size()[3]
                x = F.pad(x, (diffX // 2, diffX - diffX // 2,
                                diffY // 2, diffY - diffY // 2))

        # Concatenate
        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x


class UNet(BaseModel):
    def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
        super(UNet, self).__init__()

        self.start_conv = x2conv(in_channels, 64)
        self.down1 = encoder(64, 128)
        self.down2 = encoder(128, 256)
        self.down3 = encoder(256, 512)
        self.down4 = encoder(512, 1024)

        self.middle_conv = x2conv(1024, 1024)

        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        self._initialize_weights()

        if freeze_bn:
            self.freeze_bn()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        x1 = self.start_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.middle_conv(self.down4(x4))

        x = self.up1(x4, x)
        x = self.up2(x3, x)
        x = self.up3(x2, x)
        x = self.up4(x1, x)

        x = self.final_conv(x)
        return x

    def get_backbone_params(self):
        # There is no backbone for unet, all the parameters are trained from scratch
        return []

    def get_decoder_params(self):
        return self.parameters()

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d): module.eval()




# U-Net with a ResNet backbone

class UNetResnet(BaseModel):
    def __init__(self, num_classes, in_channels=3, backbone='resnet50', pretrained=True, freeze_bn=False, freeze_backbone=False, **_):
        super(UNetResnet, self).__init__()
        model = getattr(resnet, backbone)(pretrained, norm_layer=nn.BatchNorm2d)

        self.initial = list(model.children())[:4]
        if in_channels != 3:
            self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.initial = nn.Sequential(*self.initial)

        # encoder
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

        # decoder
        self.conv1 = nn.Conv2d(2048, 192, kernel_size=3, stride=1, padding=1)
        self.upconv1 =  nn.ConvTranspose2d(192, 128, 4, 2, 1, bias=False)

        self.conv2 = nn.Conv2d(1152, 128, kernel_size=3, stride=1, padding=1)
        self.upconv2 = nn.ConvTranspose2d(128, 96, 4, 2, 1, bias=False)

        self.conv3 = nn.Conv2d(608, 96, kernel_size=3, stride=1, padding=1)
        self.upconv3 = nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False)

        self.conv4 = nn.Conv2d(320, 64, kernel_size=3, stride=1, padding=1)
        self.upconv4 = nn.ConvTranspose2d(64, 48, 4, 2, 1, bias=False)
        
        self.conv5 = nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1)
        self.upconv5 = nn.ConvTranspose2d(48, 32, 4, 2, 1, bias=False)

        self.conv6 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(32, num_classes, kernel_size=1, bias=False)

        initialize_weights(self)

        if freeze_bn:
            self.freeze_bn()
        if freeze_backbone: 
            set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False)

    def forward(self, x):
        H, W = x.size(2), x.size(3)
        x1 = self.layer1(self.initial(x))
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        
        x = self.upconv1(self.conv1(x4))
        x = F.interpolate(x, size=(x3.size(2), x3.size(3)), mode="bilinear", align_corners=True)
        x = torch.cat([x, x3], dim=1)
        x = self.upconv2(self.conv2(x))

        x = F.interpolate(x, size=(x2.size(2), x2.size(3)), mode="bilinear", align_corners=True)
        x = torch.cat([x, x2], dim=1)
        x = self.upconv3(self.conv3(x))

        x = F.interpolate(x, size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True)
        x = torch.cat([x, x1], dim=1)

        x = self.upconv4(self.conv4(x))

        x = self.upconv5(self.conv5(x))

        # if the input is not divisible by the output stride
        if x.size(2) != H or x.size(3) != W:
            x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=True)

        x = self.conv7(self.conv6(x))
        return x

    def get_backbone_params(self):
        return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(), 
                    self.layer3.parameters(), self.layer4.parameters())

    def get_decoder_params(self):
        return chain(self.conv1.parameters(), self.upconv1.parameters(), self.conv2.parameters(), self.upconv2.parameters(),
                    self.conv3.parameters(), self.upconv3.parameters(), self.conv4.parameters(), self.upconv4.parameters(),
                    self.conv5.parameters(), self.upconv5.parameters(), self.conv6.parameters(), self.conv7.parameters())

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d): module.eval()
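The hard-coded decoder widths in `UNetResnet` (2048, 1152, 608, 320) are not arbitrary: each concatenated width is an up-conv output plus the matching backbone stage width. The sketch below redoes that arithmetic, assuming the repository's `resnet` module uses the usual ResNet stage widths.

It also suggests that the default layer sizes fit only bottleneck backbones (resnet50 and deeper). This is our inference and is not stated in the repository. With resnet18/34, these numbers would not match, and `conv1` would need 512 rather than 2048 input channels.

```python
# Output channels of layer1..layer4 in standard ResNets (assumed widths)
BOTTLENECK_WIDTHS = (256, 512, 1024, 2048)  # resnet50 / resnet101 / resnet152
BASIC_WIDTHS = (64, 128, 256, 512)          # resnet18 / resnet34

def decoder_in_channels(widths, up_out=(128, 96, 64)):
    """Channels entering conv2, conv3 and conv4 after each concatenation."""
    # upconv1/2/3 output 128, 96, 64 channels; they are concatenated with
    # the layer3, layer2 and layer1 features respectively
    skips = (widths[2], widths[1], widths[0])
    return [u + s for u, s in zip(up_out, skips)]

print(decoder_in_channels(BOTTLENECK_WIDTHS))  # [1152, 608, 320]
print(decoder_in_channels(BASIC_WIDTHS))       # [384, 224, 128]
```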

All the required modules are defined in the __init__ method of class UNet.

start_conv is the initial x2conv block. It maps the input channels (in_channels, 3 by default for RGB images; the paper's input has 1) to 64.

It is followed by 4 encoder (downsampling) modules, which increase the channels from 64 to 128, 256, 512, and 1024 in turn.

In the middle there is also a 1024 → 1024 convolution block (middle_conv).

Then come 4 decoder (upsampling) modules, which reduce the channels from 1024 to 512, 256, 128, and 64 in turn.

After upsampling, final_conv, a 1×1 convolution (a per-pixel MLP), acts as the classification layer and maps the 64 channels to num_classes.

forward() connects these layers together.
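Every convolution in this implementation uses padding=1, so the spatial size only changes at the pooling and up-convolution steps. The output therefore has the same height and width as the input, unlike the paper's 572 → 388. The plain-Python sketch below traces the shapes through `UNet.forward` for a square input. It is our own illustration of the repository code, and it assumes the decoder's resize step keeps each stage aligned with its skip map.

```python
import math

def repo_unet_shapes(h, num_classes=2):
    """(layer, channels, side length) after each step of the repository's UNet.forward."""
    sizes = [h]                       # x1: start_conv uses padding=1, size unchanged
    for _ in range(4):                # down1..down4 each end with a 2x2 max pool
        sizes.append(math.ceil(sizes[-1] / 2))  # ceil_mode=True rounds odd sizes up
    enc = [("x1", 64), ("x2", 128), ("x3", 256), ("x4", 512), ("middle", 1024)]
    trace = [(name, c, s) for (name, c), s in zip(enc, sizes)]
    # up1..up4 are resized to match x4, x3, x2, x1 before concatenation
    dec = [("up1", 512), ("up2", 256), ("up3", 128), ("up4", 64)]
    trace += [(name, c, s) for (name, c), s in zip(dec, sizes[3::-1])]
    trace.append(("final_conv", num_classes, h))
    return trace

for row in repo_unet_shapes(256):
    print(row)
```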

The following sections break down each function one by one.

x2conv

Contains two convolutional layers wrapped in an nn.Sequential, each followed by BatchNorm and ReLU. The first convolution is 3×3 with padding=1 and maps in_channels to inner_channels (out_channels // 2 by default). Because of padding=1, H and W stay the same, unlike the unpadded convolutions in the paper, which lose 2 pixels each.

The second convolution is also 3×3 with padding=1 and changes the number of channels to out_channels.

Every stage of the network can be built from these two convolutions.

def x2conv(in_channels, out_channels, inner_channels=None):
    inner_channels = out_channels // 2 if inner_channels is None else inner_channels
    down_conv = nn.Sequential(
        nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(inner_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))
    return down_conv
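A quick check of the padding behavior: with padding=1 a 3×3 convolution leaves H and W unchanged, whereas the paper's unpadded convolution removes 2 pixels. The channel counts here are chosen only for illustration: 3 in and 32 out, i.e. out_channels // 2 for a 64-channel x2conv.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)

same = nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False)   # as in x2conv
valid = nn.Conv2d(3, 32, kernel_size=3, padding=0, bias=False)  # as in the paper

print(same(x).shape)   # torch.Size([1, 32, 64, 64]): H and W preserved
print(valid(x).shape)  # torch.Size([1, 32, 62, 62]): 2 pixels lost per conv
```

bias=False is a common choice when BatchNorm follows the convolution. BatchNorm's own learnable shift makes a convolution bias redundant.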

encoder

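Contains two parts. The first is the x2conv module (two 3×3 convolutions with padding=1). The second is MaxPool2d with kernel_size=2, which halves the spatial size; ceil_mode=True rounds odd sizes up. Note that this module convolves before pooling, whereas the paper pools first and then convolves. The stage-by-stage channel counts and resolutions (64 channels at full size, 128 at 1/2, up to 1024 at 1/16) end up the same.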

class encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = x2conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x = self.pool(x)
        return x
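ceil_mode=True matters when a feature map has an odd side length. The last, partial 2×2 window is kept, so the size is rounded up instead of down. The 75×75 input below is an arbitrary example.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 75, 75)  # odd spatial size

print(nn.MaxPool2d(kernel_size=2, ceil_mode=True)(x).shape)   # torch.Size([1, 64, 38, 38])
print(nn.MaxPool2d(kernel_size=2, ceil_mode=False)(x).shape)  # torch.Size([1, 64, 37, 37])
```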

decoder

It mirrors the encoder. First, a 2D transposed convolution (often called deconvolution) upsamples the input. It takes in_channels and outputs in_channels // 2, with kernel_size=2 and stride=2, which doubles H and W. The two-layer convolution of x2conv follows.

forward also receives x_copy, the feature map saved from the corresponding encoder stage. If the upsampled x does not match x_copy in size, x is resized by bilinear interpolation, or zero-padded when interpolate=False. The paper instead crops x_copy. x_copy and x are then concatenated along the channel dimension, and the result passes through x2conv, which reduces the number of channels.

class decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up_conv = x2conv(in_channels, out_channels)

    def forward(self, x_copy, x, interpolate=True):
        x = self.up(x)

        if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)):
            if interpolate:
                # Interpolating instead of padding
                x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)),
                                mode="bilinear", align_corners=True)
            else:
                # Padding in case the incoming volumes are of different sizes
                diffY = x_copy.size()[2] - x.size()[2]
                diffX = x_copy.size()[3] - x.size()[3]
                x = F.pad(x, (diffX // 2, diffX - diffX // 2,
                                diffY // 2, diffY - diffY // 2))

        # Concatenate
        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x
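The resize branch in decoder.forward is what lets this padded U-Net accept sizes that are not powers of two. The shapes in this sketch are arbitrary assumptions. A 75×75 skip map was pooled to 38×38, so the 2× up-convolution returns 76×76, and that map must be resized back to 75×75 before concatenation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

with torch.no_grad():
    x_copy = torch.randn(1, 64, 75, 75)  # skip connection from the encoder
    x = torch.randn(1, 128, 38, 38)      # deeper features (75 pooled with ceil_mode=True)

    up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    x = up(x)                            # (1, 64, 76, 76): one pixel too large

    # Same fallback as decoder.forward with interpolate=True
    if x.shape[2:] != x_copy.shape[2:]:
        x = F.interpolate(x, size=tuple(x_copy.shape[2:]), mode="bilinear", align_corners=True)

    merged = torch.cat([x_copy, x], dim=1)

print(merged.shape)  # torch.Size([1, 128, 75, 75])
```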

Other applications: Wave-U-Net, for separating vocals from accompaniment

Paper address:
https://ismir2018.ismir.net/doc/pdfs/205_Paper.pdf

Project address:
https://github.com/f90/Wave-U-Net

Structure diagram: the input is a one-dimensional audio waveform. The left side is the encoder, which progressively downsamples the waveform, and the right side is the decoder, which progressively upsamples it. At each decoder stage, the higher-sample-rate features from the corresponding encoder stage are concatenated in. Finally, the sources are separated into multiple outputs: the last layer maps the K feature channels to C output channels, each corresponding to a different waveform.
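The same encoder-decoder idea carries over to 1D signals. Below is a toy, one-level sketch on raw waveforms. The layer widths, kernel sizes, activations, and the class name `TinyWaveUNet` are illustrative assumptions, not the Wave-U-Net configuration. Downsampling keeps every other sample (decimation), and upsampling uses linear interpolation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyWaveUNet(nn.Module):
    """One-level 1D U-Net on raw waveforms (illustrative, not the paper's config)."""

    def __init__(self, in_ch=1, feat=16, num_sources=2):
        super().__init__()
        self.down = nn.Conv1d(in_ch, feat, kernel_size=15, padding=7)
        self.bottom = nn.Conv1d(feat, feat * 2, kernel_size=15, padding=7)
        self.up = nn.Conv1d(feat * 2 + feat, feat, kernel_size=5, padding=2)
        self.out = nn.Conv1d(feat, num_sources, kernel_size=1)  # one channel per source

    def forward(self, wav):
        skip = F.leaky_relu(self.down(wav))      # (N, feat, T)
        x = skip[:, :, ::2]                      # decimate: keep every other sample
        x = F.leaky_relu(self.bottom(x))         # (N, 2*feat, T/2)
        x = F.interpolate(x, size=skip.size(-1), mode="linear", align_corners=True)
        x = torch.cat([x, skip], dim=1)          # skip connection along channels
        x = F.leaky_relu(self.up(x))
        return torch.tanh(self.out(x))           # (N, num_sources, T)

wav = torch.randn(2, 1, 1024)  # batch of 2 mono clips, 1024 samples each
print(TinyWaveUNet()(wav).shape)  # torch.Size([2, 2, 1024])
```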


Original post: blog.csdn.net/qq_19841133/article/details/126927383