语义分割——U-Net++(三)

简介
这篇论文《UNet++: A Nested U-Net Architecture for Medical Image Segmentation》是2018年6月的文章,DLMIA2018会议。文章对Unet改进的点主要是skip connection。
原论文地址:https://arxiv.org/pdf/1807.10165.pdf
作者知乎解读:https://zhuanlan.zhihu.com/p/44958351

一、铺垫
U-Net和FCN非常的相似,它们的结构用了一个比较经典的思路,也就是编码和解码(encoder-decoder)。当时这个结构提出的主要作用并不是分割,而是压缩图像和去噪声。输入是一幅图,经过下采样的编码,得到一串比原先图像更小的特征,相当于压缩,然后再经过一个解码,理想状况就是能还原到原来的图像。这样的话我们存一幅图的时候就只需要存一个特征和一个解码器即可。
后来把这个思路被用在了图像分割的问题上,也就是现在我们看到的U-Net结构。简单来讲:输入一幅图,编码(下采样),然后解码(上采样),然后输出一个分割结果。
先回顾一下U-Net结构:在这里插入图片描述

二、展开
U-net这个三年不动的拓扑结构真的一点儿毛病都没有吗?当然不是。首先原论文给出的结构是原图经过四次下采样,四次上采样,得到分割结果,实际呢,为什么四次?或许作者使用的数据集,四次下采样的效果好;或许四次下采样的接受域或者感受野大小正合适处理图像;或者四次降采样比较适合输入图像的尺寸等等。所以引出第一个问题:网络到底要多深最好?
很多论文给出了他们建议的网络结构,其中包括非常多的细节,比如用什么卷积,用几层,怎么降采样,学习率多少,优化器用什么,这些都是比较直观的参数,其实这些在论文中给出参数并不见得是最好的,所以关注这些的意义不大,一个网络结构,我们真正值得关注的是它的设计传达了什么信息。
比如一个2017年在CVPR上发表的一个名叫PSPNet的分割网络,你会发现,好像整体的架构和U-Net还是像的,只是降采样的数目减小了,当然,他们也针对性的增强了中间的特征抓取环节的复杂性。再比如Yoshua Bengio组最近的关于图像分割的论文《Fully Convolutional DenseNets for Semantic Segmentation》,这是他们提出的结构,名叫提拉米苏。也是U形结构,但你会发现,他们就只用了三次降采样。所以到底要多深,是不是越深越好?在这里插入图片描述
好,我们回来继续讨论到底需要多深的问题。其实这个是非常灵活的,涉及到的一个点就是特征提取器,U-Net和FCN为什么成功,因为它相当于给了一个网络的框架,具体用什么特征提取器,随便。这个时候,高引就出现了,各种在encoder上的微创新络绎不绝,最直接的就是用ImageNet里面的明星结构来套嘛,前几年的BottleNeck,Residual,还有去年的DenseNet,就比谁出文章快。这一类的论文就相当于从1到10的递进,而U-Net这个低层结构的提出却是从0到1。说句题外话,像这种从1到10的论文,引用往往不会比从0到1的论文高,因为它不自觉的局限了自己的扩展空间。
关于到底要多深这个问题,还有一个引申的问题就是,下采样对于分割网络到底是不是必须的?既然输入和输出都是相同大小的图,为什么要折腾去下采样一下再上采样呢?比较直接的回答就是下采样的理论意义:它可以增加对输入图像的一些小扰动的鲁棒性,比如图像平移,旋转等,减少过拟合的风险,降低运算量,和增加感受野的大小。在这里插入图片描述
在我的理解中,对于特征提取阶段,浅层结构可以抓取图像的一些简单的特征,比如边界,颜色,而深层结构因为感受野大了,而且经过的卷积操作多了,能抓取到图像的一些说不清道不明的抽象特征。总之,浅有浅的侧重,深有深的优势。那既然浅层特征和深层特征都很重要,U-Net为什么只在4层以后才返回去,也就是只去抓深层特征。在这里插入图片描述

三、主体
沿着上面的思路继续,如果排除一切其他干扰,既然我们不知道需要多深,是不是会衍生出一系列像这样子的U-Net,它们的深度不一。为了搞清楚是不是越深越好,我们应该对它们做一下实验,看看它们各自的分割表现。然后我们得出:不同层次特征的重要性对于不同的数据集是不一样的,并不是说我设计一个4层的U-Net,就一定对所有数据集的分割问题都最优。在这里插入图片描述
那么接下来我们的目标就是使用浅层和深层的特征。但是总不能训练这么多个U-Net吧。所以像下图这样就把1~4层的U-Net全给连一起了。它们的子集包含1层U-Net,2层U-Net,以此类推。这个结构的好处就是我不管你哪个深度的特征有效,我干脆都给你用上,让网络自己去学习不同深度的特征的重要性。第二个好处是它共享了一个特征提取器,只训练一个encoder,它的不同层次的特征由不同的decoder路径来还原。这个encoder依旧可以灵活替换。在这里插入图片描述
可惜的是,这个网络结构是不能被训练的,不会由任何梯度会经过这个红色区域,因为它和算loss function的地方是在反向传播时是断开的。于是我们顺着思路一直走,就想到了这个样子。在这里插入图片描述
这个结构由UC Berkeley的团队提出,发表在今年的CVPR上,是一个oral的论文,题目是"Deep Layer Aggregation"。不难发现这个结构强行去掉了U-Net本身自带的长连接。取而代之的是一系列的短连接。
那么我们来看看U-Net引以为傲的长连接到底有什么优点。在这里插入图片描述
U-Net中的长连接是有必要的,它联系了输入图像的很多信息,有助于还原降采样所带来的信息损失,在一定程度上,我觉得它和残差的操作非常类似,也就是residual操作,x+f(x)。因此,我们需要给出一个综合长连接和短连接的方案。
这个结构,就是在MICCAI中发表的UNet++。
它可以抓取不同层次的特征,将它们通过特征叠加的方式整合,不同层次的特征,或者说不同大小的感受野,对于大小不一的目标对象的敏感度是不同的。比如,感受野大的特征,可以很容易的识别出大物体的,但是在实际分割中,大物体边缘信息和小物体本身是很容易被深层网络一次次的下采样和一次次上采样给弄丢的,这个时候就可能需要感受野小的特征来帮助。

在这里插入图片描述
刚才说这个结构在反向传播的时候中间部分会收不到过来的梯度,如果只用最右边的一个loss来做的话。
一个非常直接的解决方案就是深监督deep supervision。这个概念不是新的,有很多对U-Net对改进论文中也有用到,具体的实现操作就是在图中X0.1后面加一个1*1的卷积核,相当于去监督每个level,或者每个分支的U-Net的输出。
在训练过程中在各个level的子网络中加了这种深监督,好处是在测试时可以be pruned(剪枝)。在这里插入图片描述
关注被剪掉的这部分,你会发现,在测试的阶段,由于输入的图像只会前向传播,扔掉这部分对前面的输出完全没有影响的,而在训练阶段,因为既有前向,又有反向传播,被剪掉的部分是会帮助其他部分做权重更新的。因为在深监督的过程中,每个子网络的输出都其实已经是图像的分割结果了,所以如果小的子网络的输出结果已经足够好了,我们可以随意的剪掉那些多余的部分了。

四、代码

#https://github.com/MrGiovanni/UNetPlusPlus
import keras
import tensorflow as tf
from keras.models import Model
from keras import backend as K
from keras.layers import Input, merge, Conv2D, ZeroPadding2D, UpSampling2D, Dense, concatenate, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalAveragePooling2D, MaxPooling2D
from keras.layers.core import Dense, Dropout, Activation
from keras.layers import BatchNormalization, Dropout, Flatten, Lambda
from keras.layers.advanced_activations import ELU, LeakyReLU
from keras.optimizers import Adam, RMSprop, SGD
from keras.regularizers import l2
from keras.layers.noise import GaussianDropout

import numpy as np
smooth = 1.
dropout_rate = 0.5
act = "relu"
########################################
# 2D Standard
########################################

def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):

    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)

    return x

########################################

"""
Standard UNet++ [Zhou et.al, 2018]
Total params: 9,041,601
"""
def Nest_Net(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):

    nb_filter = [32,64,128,256,512]

    # Handle Dimension Ordering for different backends
    global bn_axis
    if K.image_dim_ordering() == 'tf':
      bn_axis = 3
      img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
      bn_axis = 1
      img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
        model = Model(input=img_input, output=[nestnet_output_1,
                                               nestnet_output_2,
                                               nestnet_output_3,
                                               nestnet_output_4])
    else:
        model = Model(input=img_input, output=[nestnet_output_4])

    return model

代码来源:https://github.com/MrGiovanni/UNetPlusPlus

猜你喜欢

转载自blog.csdn.net/qq_42823043/article/details/90438874