Building U^2-Net with TensorFlow 2.x


Preface

U^2-Net is an image segmentation network I used a while back. It performed well and left a strong impression on me, so, also as a learning exercise, I reimplemented it in TensorFlow 2.x at the time. Publishing this now is a bit late, but I'm writing it down so I don't forget. Related links:
Paper
Official code (PyTorch)
Detailed walkthrough


I. Environment Setup

pip install tensorflow-gpu
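
If the install succeeded, a quick sanity check (a minimal sketch; nothing here is specific to U^2-Net) confirms the installed version and that the GPU is visible:

import tensorflow as tf

# Print the installed version and the visible GPUs; an empty GPU list means
# TensorFlow will silently fall back to the CPU.
print(tf.__version__)
print(tf.config.list_physical_devices("GPU"))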

II. Network Architecture

1. Architecture diagram

The overall network architecture is shown below:
[U^2-Net architecture diagram]

2. Imports

import tensorflow.keras as k
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Conv2D,MaxPool2D,BatchNormalization,ReLU,UpSampling2D

3. The RSU block

The RSU block is essentially a miniature U-Net: the first half extracts features through a series of downsampling steps, and the second half fuses them back through upsampling and concatenation. Note that the four consecutive RSU variants (RSU7 through RSU4) differ only in depth and could be written much more compactly; here I spell them all out in full (a parameterized sketch follows after the four blocks). The code is as follows:

#Basic convolution block: Conv2D + BatchNorm + ReLU
class REBNCONV(Model):
    def __init__(self,out_ch=3,dirate=1):
        super(REBNCONV, self).__init__()
        self.conv=Sequential()
        self.conv.add(Conv2D(out_ch,kernel_size=(3,3),strides=(1,1),padding="same",dilation_rate=dirate))
        self.conv.add(BatchNormalization())
        self.conv.add(ReLU())
    def call(self, inputs, training=None, mask=None):
        x=self.conv(inputs)
        return x
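#Quick sanity note (illustrative only): with padding="same" and stride 1, the
#dilated variants keep the spatial size and only change the channel count, e.g.
#REBNCONV(out_ch=64, dirate=2)(tf.random.normal((1, 288, 288, 3))) has shape (1, 288, 288, 64).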
#First RSU block: RSU-7
class RSU7(Model):
    def __init__(self,mid_ch=12,out_ch=3):
        super(RSU7, self).__init__()
        self.rebnconv0=REBNCONV(out_ch,dirate=1)

        self.rebnconv1=REBNCONV(mid_ch,dirate=1)
        self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv2=REBNCONV(mid_ch,dirate=1)
        self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv3=REBNCONV(mid_ch,dirate=1)
        self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv4=REBNCONV(mid_ch,dirate=1)
        self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv5=REBNCONV(mid_ch,dirate=1)
        self.pool5=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv6=REBNCONV(mid_ch,dirate=1)
        self.rebnconv7=REBNCONV(mid_ch,dirate=2)

        self.rebnconv6d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv1d=REBNCONV(out_ch,dirate=1)
    def call(self, inputs, training=None, mask=None):
        hx0=self.rebnconv0(inputs)
        hx1=self.rebnconv1(hx0)
        hx=self.pool1(hx1)

        hx2=self.rebnconv2(hx)
        hx=self.pool2(hx2)

        hx3=self.rebnconv3(hx)
        hx=self.pool3(hx3)

        hx4=self.rebnconv4(hx)
        hx=self.pool4(hx4)

        hx5=self.rebnconv5(hx)
        hx=self.pool5(hx5)

        hx6=self.rebnconv6(hx)
        hx7=self.rebnconv7(hx6)

        hx6d=self.rebnconv6d(tf.concat((hx7,hx6),axis=3))
        hx6d_up=UpSampling2D((2,2),interpolation="bilinear")(hx6d) #upsample; UpSampling2D has no weights, so building it inline works

        hx5d=self.rebnconv5d(tf.concat((hx6d_up,hx5),axis=3))
        hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)

        hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
        hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)

        hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
        hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)

        hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
        hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)

        hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
        return hx1d+hx0
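
#Note: RSU7 pools five times, so its input height/width should be divisible by 32
#(e.g. 288x288 shrinks to 9x9 at the bottleneck) for the skip concatenations to align;
#the output keeps the input resolution, with out_ch channels.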

#Second RSU block: RSU-6
class RSU6(Model):
    def __init__(self,mid_ch=12,out_ch=3):
        super(RSU6, self).__init__()
        self.rebnconv0=REBNCONV(out_ch,dirate=1)

        self.rebnconv1=REBNCONV(mid_ch,dirate=1)
        self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv2=REBNCONV(mid_ch,dirate=1)
        self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv3=REBNCONV(mid_ch,dirate=1)
        self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv4=REBNCONV(mid_ch,dirate=1)
        self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))

        self.rebnconv5=REBNCONV(mid_ch,dirate=1)

        self.rebnconv6=REBNCONV(mid_ch,dirate=2)

        self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
        self.rebnconv1d=REBNCONV(out_ch,dirate=1)
    def call(self, inputs, training=None, mask=None):
        hx0=self.rebnconv0(inputs)

        hx1=self.rebnconv1(hx0)
        hx=self.pool1(hx1)

        hx2=self.rebnconv2(hx)
        hx=self.pool2(hx2)

        hx3=self.rebnconv3(hx)
        hx=self.pool3(hx3)

        hx4=self.rebnconv4(hx)
        hx=self.pool4(hx4)

        hx5=self.rebnconv5(hx)
        hx6=self.rebnconv6(hx5)

        hx5d=self.rebnconv5d(tf.concat((hx6,hx5),axis=3))
        hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)

        hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
        hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)

        hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
        hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)

        hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
        hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)

        hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
        return hx1d+hx0

#Third RSU block: RSU-5
class RSU5(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))

        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))

        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))

        self.rebnconv4 = REBNCONV(mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)

        hx1 = self.rebnconv1(hx0)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(tf.concat((hx5, hx4), axis=3))
        hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)

        hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)

        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)

        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + hx0

#Fourth RSU block: RSU-4
class RSU4(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))

        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))

        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)

        hx1 = self.rebnconv1(hx0)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)  # feed hx3 here, not the pooled hx (matches the official implementation)

        hx3d = self.rebnconv3d(tf.concat((hx4, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)

        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)

        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + hx0
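
As noted above, RSU7 through RSU4 differ only in their depth, so the four classes can be folded into a single parameterized one. The following is my own refactor sketch, not part of the original post; `height` matches the L in the paper's RSU-L notation, and it is meant to be equivalent to the expanded classes above:

#Generic RSU block: RSU(7, 32, 64) reproduces RSU7(32, 64), RSU(6, ...) reproduces RSU6, and so on
class RSU(Model):
    def __init__(self, height=7, mid_ch=12, out_ch=3):
        super(RSU, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        # encoder: height-1 blocks, with a pool after each one except the last
        self.enc = [REBNCONV(mid_ch, dirate=1) for _ in range(height - 1)]
        self.pools = [MaxPool2D(pool_size=(2, 2), strides=(2, 2)) for _ in range(height - 2)]
        self.bottom = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        # decoder mirrors the encoder; the last block maps back to out_ch
        self.dec = [REBNCONV(mid_ch, dirate=1) for _ in range(height - 2)] + [REBNCONV(out_ch, dirate=1)]
        self.up = UpSampling2D((2, 2), interpolation="bilinear")

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        hx, skips = hx0, []
        for i, conv in enumerate(self.enc):
            hx = conv(hx)
            skips.append(hx)
            if i < len(self.pools):       # no pool after the deepest encoder block
                hx = self.pools[i](hx)
        hx = self.bottom(hx)
        for i, conv in enumerate(self.dec):
            hx = conv(tf.concat((hx, skips[-1 - i]), axis=3))
            if i < len(self.dec) - 1:     # no upsample after the final decoder block
                hx = self.up(hx)
        return hx + hx0                   # residual connection, as in the paper

With this class, self.stage1=RSU7(32,64) in the network below would become self.stage1=RSU(7,32,64), and similarly for the other stages.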

4. The dilated block (RSU-4F)

After passing through the four RSU stages, the feature maps become very small (with a 288×288 input, which the resolution comments in the code below assume, the maps here are down to 18×18). Downsampling any further would lose too much information, so the authors replace the pooling and upsampling inside this block with dilated convolutions. In other words, every intermediate feature map in the dilated block keeps the same resolution as the block's input. The code is as follows:

#Dilated block: RSU-4F (dilation rates 1, 2, 4, 8; no pooling)
class RSU4F(Model):
    def __init__(self,mid_ch=12,out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconv0=REBNCONV(out_ch,dirate=1)
        self.rebnconv1=REBNCONV(mid_ch,dirate=1)
        self.rebnconv2=REBNCONV(mid_ch,dirate=2)
        self.rebnconv3=REBNCONV(mid_ch,dirate=4)
        self.rebnconv4=REBNCONV(mid_ch,dirate=8)
        self.rebnconv3d=REBNCONV(mid_ch,dirate=4)
        self.rebnconv2d=REBNCONV(mid_ch,dirate=2)
        self.rebnconv1d=REBNCONV(out_ch,dirate=1)
    def call(self, inputs, training=None, mask=None):
        hx0=self.rebnconv0(inputs)
        hx1=self.rebnconv1(hx0)
        hx2=self.rebnconv2(hx1)
        hx3=self.rebnconv3(hx2)
        hx4=self.rebnconv4(hx3)
        hx3d=self.rebnconv3d(tf.concat((hx4,hx3),axis=3))
        hx2d=self.rebnconv2d(tf.concat((hx3d,hx2),axis=3))
        hx1d=self.rebnconv1d(tf.concat((hx2d,hx1),axis=3))
        return hx1d+hx0
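
A quick sanity check (illustrative shapes only) confirms that RSU4F is resolution preserving: the dilation rates 1, 2, 4, 8 let the stacked 3×3 convolutions cover receptive fields of roughly 3, 5, 9, and 17 pixels without any pooling:

# Dummy 18x18 feature map, the size reached after the four pooled RSU stages
f = RSU4F(mid_ch=256, out_ch=512)
y = f(tf.random.normal((1, 18, 18, 512)))
print(y.shape)  # (1, 18, 18, 512): spatial size unchanged anywhere in the block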

5. The full network

Note that sigmoid is the activation used to produce every probability map (the fused output as well as all six side outputs). The full network implementation is as follows:

#U^2Net
class U2NET(Model):
    def __init__(self,out_ch=1):
        super(U2NET, self).__init__()
        #encode
        self.stage1=RSU7(32,64)
        self.pool1_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#144*144 (all resolution comments assume a 288*288 input)

        self.stage2=RSU6(32,128)
        self.pool2_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#72*72

        self.stage3=RSU5(64,256)
        self.pool3_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#36*36

        self.stage4=RSU4(128,512)
        self.pool4_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#18*18

        self.stage5=RSU4F(256,512)
        self.pool5_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#9*9

        self.stage6=RSU4F(256,512)

        #decode
        self.stage5d=RSU4F(256,512)
        self.stage4d=RSU4(128,256)
        self.stage3d=RSU5(64,128)
        self.stage2d=RSU6(32,64)
        self.stage1d=RSU7(16,64)

        #side output (probability map) for each stage
        self.side1=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        self.side2=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        self.side3=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        self.side4=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        self.side5=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        self.side6=Conv2D(out_ch,kernel_size=(3,3),padding="same")
        #fused final output
        self.outconv=Conv2D(out_ch,kernel_size=(1,1))
    def call(self, inputs, training=None, mask=None):
        hx1=self.stage1(inputs)
        hx=self.pool1_1(hx1)

        hx2=self.stage2(hx)
        hx=self.pool2_1(hx2)

        hx3=self.stage3(hx)
        hx=self.pool3_1(hx3)

        hx4=self.stage4(hx)
        hx=self.pool4_1(hx4)

        hx5=self.stage5(hx)
        hx=self.pool5_1(hx5)

        hx6=self.stage6(hx)
        hx6_up=UpSampling2D((2,2),interpolation="bilinear")(hx6)

        #decode
        hx5d=self.stage5d(tf.concat((hx6_up,hx5),axis=3))
        hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)

        hx4d=self.stage4d(tf.concat((hx5d_up,hx4),axis=3))
        hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)

        hx3d=self.stage3d(tf.concat((hx4d_up,hx3),axis=3))
        hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)

        hx2d=self.stage2d(tf.concat((hx3d_up,hx2),axis=3))
        hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)

        hx1d=self.stage1d(tf.concat((hx2d_up,hx1),axis=3))

        # side out
        d1=self.side1(hx1d)

        d2=self.side2(hx2d)
        d2=UpSampling2D((2,2),interpolation="bilinear")(d2)

        d3=self.side3(hx3d)
        d3=UpSampling2D((4,4),interpolation="bilinear")(d3)

        d4=self.side4(hx4d)
        d4=UpSampling2D((8,8),interpolation="bilinear")(d4)

        d5=self.side5(hx5d)
        d5=UpSampling2D((16,16),interpolation="bilinear")(d5)

        d6=self.side6(hx6)
        d6=UpSampling2D((32,32),interpolation="bilinear")(d6)

        out=self.outconv(tf.concat((d1,d2,d3,d4,d5,d6),axis=3))
        sig=k.activations.sigmoid #sigmoid activation for every probability map
        return sig(out),sig(d1),sig(d2),sig(d3),sig(d4),sig(d5),sig(d6)
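
The training code is not part of this post, but as a rough sketch of how these seven outputs are typically supervised: the paper applies binary cross-entropy to the fused map and to each of the six side maps and sums the terms. The 288×288 input size below is an assumption matching the resolution comments in the code above, and the data is random, just for a shape and loss sanity check:

model = U2NET(out_ch=1)
images = tf.random.normal((2, 288, 288, 3))    # dummy input batch (NHWC)
masks = tf.random.uniform((2, 288, 288, 1))    # dummy ground-truth masks in [0, 1]

# The outputs are already sigmoid-activated, so from_logits stays False.
bce = tf.keras.losses.BinaryCrossentropy()
outputs = model(images)                        # (fused, d1, d2, d3, d4, d5, d6)
loss = tf.add_n([bce(masks, o) for o in outputs])
print(loss)
# In a real training loop this forward pass would run as
# model(images, training=True) inside a tf.GradientTape.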

That's all for this article. If you need the training code, or you spot any mistakes, feel free to leave a comment.

Reprinted from blog.csdn.net/qq_55068938/article/details/127961156