A Summary of the U-Net Series of Papers

Why do U-Net-style algorithms achieve such good results in medical image processing? A large part of the answer lies in the characteristics of medical images themselves. First, the structures in medical images tend to be fixed and relatively simple. Second, the U-Net architecture is itself simple, which matches well with medical images, which are hard to obtain in quantity: with so little data, the powerful feature-representation capacity of very deep networks cannot be fully exploited anyway. Finally, datasets are genuinely difficult to acquire. Most people cannot get them at all, and even with a large amount of medical image data, annotating it is tedious work that is better left to specialists. In practice there is a real gap here: computer-vision researchers find it hard to label medical images (even telling normal from abnormal is nontrivial for a layperson), while doctors are usually just as unfamiliar with computer vision. Not everyone, but for most people, this is the reality.

1. U-Net

U-Net: Convolutional Networks for Biomedical Image Segmentation
[Figure: U-Net architecture diagram]

As the U-Net architecture diagram shows, the network is an encoder-decoder. The encoder consists of convolution and pooling, which can be understood as repeated downsampling; its purpose is to obtain feature maps at different scales. The decoder consists of convolution, feature concatenation, and upsampling. The concatenation is done along the channel dimension, producing a "thicker" feature map: because the convolution-and-pooling process in the encoder loses image detail, concatenating encoder features back in is meant to recover as much of that lost detail as possible. Still, although concatenation avoids discarding encoder information outright, how much of the lost detail can it actually recover, and is plain concatenation even the best way to fuse these features? These are questions worth considering.

U-Net has become the standard baseline for medical image segmentation mainly because of the data characteristics of medical images. First, compared with natural scenes, the semantics of medical images are relatively simple and their structure is fixed, so all of their features matter: both low-level and high-level semantic information should be preserved as much as possible so the model can learn from it. Second, medical data is hard to obtain and the available datasets are small, which creates a tension between deep networks and limited data: overly deep models overfit. Finally, compared with other segmentation models, U-Net's simple structure leaves more room for modification.
Most of the follow-up improvement work revolves around feature extraction and feature concatenation (the skip connections).
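As a quick illustration (a minimal sketch of my own, not taken from any of the papers) of what concatenating along the channel dimension does to a feature map:

import torch

# Skip connections concatenate along the channel dimension (dim=1 in NCHW),
# so the spatial size is unchanged and the channel counts add up:
decoder_feat = torch.rand(1, 64, 56, 56)  # upsampled decoder feature map
encoder_feat = torch.rand(1, 64, 56, 56)  # encoder feature map, same spatial size
fused = torch.cat([decoder_feat, encoder_feat], dim=1)
print(fused.shape)  # torch.Size([1, 128, 56, 56]) -- "thicker", same resolution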

A PyTorch implementation follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv2d_bn(nn.Module):
    """Two 3x3 conv + BN + ReLU layers: the basic U-Net building block."""
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out
    
class deconv2d_bn(nn.Module):
    """Transposed conv + BN + ReLU, used for 2x upsampling in the decoder."""
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out
    
class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        # Encoder blocks: channel widths 1 -> 8 -> 16 -> 32 -> 64 -> 128
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        # Decoder blocks: input widths are doubled by the skip concatenation
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        # Final 3x3 conv projecting to a single output channel
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        # 2x upsampling stages
        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # Encoder: four double-conv blocks, each followed by 2x2 max pooling
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)
        
        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2,2)
        
        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3,2)
        
        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4,2)
        
        conv5 = self.layer5_conv(pool4)  # bottleneck

        # Decoder: upsample, concatenate the matching encoder feature map
        # along the channel dimension (dim=1), then apply a double-conv block
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)
        
        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2,conv3],dim=1)
        conv7 = self.layer7_conv(concat2)
        
        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3,conv2],dim=1)
        conv8 = self.layer8_conv(concat3)
        
        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4,conv1],dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)   # project to a single output channel
        outp = self.sigmoid(outp)         # per-pixel foreground probability
        return outp
    

model = Unet()
inp = torch.rand(10,1,224,224)
outp = model(inp)
print(outp.shape)
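With 224x224 inputs the four pooling stages and four upsampling stages line up exactly (224 is divisible by 16), so this prints torch.Size([10, 1, 224, 224]).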

2. DC-UNet

DC-UNet: Rethinking the U-Net Architecture with Dual Channel Efficient CNN for Medical Images Segmentation
Problem: U-Net has become the mainstream medical image segmentation architecture, but the original network is a plain encoder-decoder, so image features are not extracted as efficiently as they could be.
Solution: design a more efficient CNN block to replace the encoder and decoder convolutions, and apply residual modules in place of the plain skip connections between encoder and decoder, improving on the existing U-Net model.
[Figure: DC-UNet architecture diagram]
The overall network structure is not much different from U-Net: the double-conv blocks are replaced by DC blocks and the skip connections by ResPaths.
A Keras implementation follows:

# -*- coding: utf-8 -*-
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, \
    concatenate, BatchNormalization, Activation, add
from keras.models import Model


def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), activation='relu', name=None):
    '''
    2D Convolutional layers

    Arguments:
        x {keras layer} -- input layer
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters

    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(1, 1)})
        activation {str} -- activation function (default: {'relu'})
        name {str} -- name of the layer (default: {None})

    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    if activation is None:
        return x

    x = Activation(activation, name=name)(x)

    return x


def trans_conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(2, 2), name=None):
    '''
    2D Transposed Convolutional layers

    Arguments:
        x {keras layer} -- input layer
        filters {int} -- number of filters
        num_row {int} -- number of rows in filters
        num_col {int} -- number of columns in filters

    Keyword Arguments:
        padding {str} -- mode of padding (default: {'same'})
        strides {tuple} -- stride of convolution operation (default: {(2, 2)})
        name {str} -- name of the layer (default: {None})

    Returns:
        [keras layer] -- [output layer]
    '''

    x = Conv2DTranspose(filters, (num_row, num_col), strides=strides, padding=padding)(x)
    x = BatchNormalization(axis=3, scale=False)(x)

    return x


def DCBlock(U, inp, alpha=1.67):
    '''
    DC Block

    Arguments:
        U {int} -- number of filters in the corresponding U-Net stage
        inp {keras layer} -- input layer

    Returns:
        [keras layer] -- [output layer]
    '''

    W = alpha * U

    # Each branch stacks three 3x3 convolutions; the variable names refer to
    # effective receptive fields (two stacked 3x3s cover 5x5, three cover 7x7).
    conv3x3_1 = conv2d_bn(inp, int(W * 0.167), 3, 3,
                          activation='relu', padding='same')

    conv5x5_1 = conv2d_bn(conv3x3_1, int(W * 0.333), 3, 3,
                          activation='relu', padding='same')

    conv7x7_1 = conv2d_bn(conv5x5_1, int(W * 0.5), 3, 3,
                          activation='relu', padding='same')

    out1 = concatenate([conv3x3_1, conv5x5_1, conv7x7_1], axis=3)
    out1 = BatchNormalization(axis=3)(out1)

    conv3x3_2 = conv2d_bn(inp, int(W * 0.167), 3, 3,
                          activation='relu', padding='same')

    conv5x5_2 = conv2d_bn(conv3x3_2, int(W * 0.333), 3, 3,
                          activation='relu', padding='same')

    conv7x7_2 = conv2d_bn(conv5x5_2, int(W * 0.5), 3, 3,
                          activation='relu', padding='same')
    out2 = concatenate([conv3x3_2, conv5x5_2, conv7x7_2], axis=3)
    out2 = BatchNormalization(axis=3)(out2)

    out = add([out1, out2])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    return out
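Note the "dual channel" in the block above: two parallel multi-scale branches are built from the same input and summed (the add call). This dual-branch design, compared with the single branch plus residual shortcut of the MultiRes block in MultiResUNet that DC-UNet builds on, is what gives the architecture its name.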


def ResPath(filters, length, inp):
    '''
    ResPath

    Arguments:
        filters {int} -- [description]
        length {int} -- length of ResPath
        inp {keras layer} -- input layer

    Returns:
        [keras layer] -- [output layer]
    '''

    shortcut = inp
    shortcut = conv2d_bn(shortcut, filters, 1, 1,
                         activation=None, padding='same')

    out = conv2d_bn(inp, filters, 3, 3, activation='relu', padding='same')

    out = add([shortcut, out])
    out = Activation('relu')(out)
    out = BatchNormalization(axis=3)(out)

    for i in range(length - 1):
        shortcut = out
        shortcut = conv2d_bn(shortcut, filters, 1, 1,
                             activation=None, padding='same')

        out = conv2d_bn(out, filters, 3, 3, activation='relu', padding='same')

        out = add([shortcut, out])
        out = Activation('relu')(out)
        out = BatchNormalization(axis=3)(out)

    return out
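ResPath replaces a plain skip connection with a chain of residual convolution blocks; in the model below the chain length shrinks from 4 at the shallowest stage to 1 at the deepest. The intuition, carried over from MultiResUNet, is that shallow encoder features need more processing to close the semantic gap with the much more processed decoder features they are concatenated with.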


def DCUNet(height, width, channels):
    '''
    DC-UNet

    Arguments:
        height {int} -- height of image
        width {int} -- width of image
        channels {int} -- number of channels in image

    Returns:
        [keras model] -- DC-UNet model
    '''

    inputs = Input((height, width, channels))

    dcblock1 = DCBlock(32, inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(dcblock1)
    dcblock1 = ResPath(32, 4, dcblock1)

    dcblock2 = DCBlock(32 * 2, pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(dcblock2)
    dcblock2 = ResPath(32 * 2, 3, dcblock2)

    dcblock3 = DCBlock(32 * 4, pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2))(dcblock3)
    dcblock3 = ResPath(32 * 4, 2, dcblock3)

    dcblock4 = DCBlock(32 * 8, pool3)
    pool4 = MaxPooling2D(pool_size=(2, 2))(dcblock4)
    dcblock4 = ResPath(32 * 8, 1, dcblock4)

    dcblock5 = DCBlock(32 * 16, pool4)

    up6 = concatenate([Conv2DTranspose(
        32 * 8, (2, 2), strides=(2, 2), padding='same')(dcblock5), dcblock4], axis=3)
    dcblock6 = DCBlock(32 * 8, up6)

    up7 = concatenate([Conv2DTranspose(
        32 * 4, (2, 2), strides=(2, 2), padding='same')(dcblock6), dcblock3], axis=3)
    dcblock7 = DCBlock(32 * 4, up7)

    up8 = concatenate([Conv2DTranspose(
        32 * 2, (2, 2), strides=(2, 2), padding='same')(dcblock7), dcblock2], axis=3)
    dcblock8 = DCBlock(32 * 2, up8)

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(
        2, 2), padding='same')(dcblock8), dcblock1], axis=3)
    dcblock9 = DCBlock(32, up9)

    conv10 = conv2d_bn(dcblock9, 1, 1, 1, activation='sigmoid')

    model = Model(inputs=[inputs], outputs=[conv10])

    return model
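A quick smoke test of the builder (the 224x224 size is my own assumption; any height and width divisible by 16 keep the pooling and upsampling stages aligned):

model = DCUNet(height=224, width=224, channels=1)
model.summary()  # prints the layer graph and parameter counts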


3. TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

Motivation: deep learning has been very successful in medical image processing, and in medical image segmentation in particular, U-Net has become the baseline because of the characteristics of medical images discussed above. But U-Net builds its multi-scale features essentially through convolution and pooling, a process that loses image detail. A stronger encoder-decoder should therefore help alleviate this information loss, so the authors combine a Transformer with U-Net, propose TransUNet for medical image segmentation, and verify its effectiveness experimentally.
The TransUNet network structure is shown in the figure below.
[Figure: TransUNet architecture diagram]
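TransUNet itself uses a pretrained ResNet + ViT hybrid encoder; the following is only a minimal PyTorch sketch of the core idea (CNN features flattened into tokens, run through a Transformer encoder, then reshaped and decoded with a skip connection). All layer sizes and names are my own illustrative assumptions, and positional embeddings are omitted for brevity:

import torch
import torch.nn as nn

class TinyTransUNet(nn.Module):
    # A rough sketch of the TransUNet idea, NOT the paper's architecture:
    # a small CNN encoder produces a feature map, the map is flattened into a
    # token sequence for a Transformer encoder, and the tokens are reshaped
    # back into a map and decoded with one skip connection.
    def __init__(self, in_ch=1, base_ch=32, depth=2, heads=4):
        super().__init__()
        self.stage1 = nn.Sequential(  # 1/2 resolution, kept as the skip feature
            nn.Conv2d(in_ch, base_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(base_ch), nn.ReLU(inplace=True))
        self.stage2 = nn.Sequential(  # 1/4 resolution, fed to the Transformer
            nn.Conv2d(base_ch, base_ch * 2, 3, stride=2, padding=1),
            nn.BatchNorm2d(base_ch * 2), nn.ReLU(inplace=True))
        layer = nn.TransformerEncoderLayer(d_model=base_ch * 2, nhead=heads,
                                           dim_feedforward=base_ch * 4,
                                           batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
        self.up = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
        self.fuse = nn.Conv2d(base_ch * 2, base_ch, 3, padding=1)
        self.head = nn.ConvTranspose2d(base_ch, 1, 2, stride=2)

    def forward(self, x):
        s1 = self.stage1(x)                     # (B, C, H/2, W/2)
        s2 = self.stage2(s1)                    # (B, 2C, H/4, W/4)
        b, c, h, w = s2.shape
        tokens = s2.flatten(2).transpose(1, 2)  # one token per spatial position
        tokens = self.transformer(tokens)       # global self-attention
        feat = tokens.transpose(1, 2).reshape(b, c, h, w)
        up = self.up(feat)                      # back to (B, C, H/2, W/2)
        fused = torch.relu(self.fuse(torch.cat([up, s1], dim=1)))  # skip fusion
        return torch.sigmoid(self.head(fused))  # (B, 1, H, W) probability mask

model = TinyTransUNet()
print(model(torch.rand(2, 1, 64, 64)).shape)  # torch.Size([2, 1, 64, 64])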
Transformers mean this field is about to get seriously crowded!!!!

Summary and Rant

After reading about ten papers in the U-Net series: compared with the original U-Net, performance does improve to some extent, but model complexity keeps growing, and parameter counts and computation time increase substantially. There is nothing wrong with incremental improvements, and most of us do incremental work, but there are not many real highlights here, hahaha. To borrow my advisor's words: even if you cannot do a piece of work yourself, you should learn to judge whether it is good work.
Looking forward to genuinely groundbreaking work.

Source: blog.csdn.net/hasque2019/article/details/127052635