Каталог статей
предисловие
U^2Net — это сеть сегментации изображений, которую я использовал раньше. Поскольку эффект лучше, она меня больше впечатляет. Также для обучения я реконструировал эту сеть с помощью TensorFlow2. На случай, если вы забудете, запишите ее. Ниже приведены ссылки на соответствующие документы: Подробная интерпретация официального кода
бумажного адреса
(pytorch)
1. Строительство окружающей среды
pip install tensorflow-gpu
2. Структура сети
1. Схема структуры сети
Схема структуры сети выглядит следующим образом:
2. Импортируйте заголовочный файл
import tensorflow.keras as k
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Conv2D,MaxPool2D,BatchNormalization,ReLU,UpSampling2D
3. Структура РСУ
Блок развязки RSU эквивалентен небольшой структуре unet.Первая половина извлекает признаки с помощью серии понижающей дискретизации, а вторая половина обеспечивает слияние признаков с помощью повышающей дискретизации и объединения. Следует отметить, что 4 последовательных модуля RSU здесь на самом деле могут быть написаны более упрощенно.Здесь я прямо раскрою и напишу их все.Код выглядит следующим образом:
#基本卷积块
class REBNCONV(Model):
def __init__(self,out_ch=3,dirate=1):
super(REBNCONV, self).__init__()
self.conv=Sequential()
self.conv.add(Conv2D(out_ch,kernel_size=(3,3),strides=(1,1),padding="same",dilation_rate=dirate))
self.conv.add(BatchNormalization())
self.conv.add(ReLU())
def call(self, inputs, training=None, mask=None):
x=self.conv(inputs)
return x
#第一个RSU结构
class RSU7(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU7, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv2=REBNCONV(mid_ch,dirate=1)
self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv3=REBNCONV(mid_ch,dirate=1)
self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv4=REBNCONV(mid_ch,dirate=1)
self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv5=REBNCONV(mid_ch,dirate=1)
self.pool5=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv6=REBNCONV(mid_ch,dirate=1)
self.rebnconv7=REBNCONV(mid_ch,dirate=2)
self.rebnconv6d=REBNCONV(mid_ch,dirate=1)
self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
h0=self.rebnconv0(inputs)
hx1=self.rebnconv1(h0)
hx=self.pool1(hx1)
hx2=self.rebnconv2(hx)
hx=self.pool2(hx2)
hx3=self.rebnconv3(hx)
hx=self.pool3(hx3)
hx4=self.rebnconv4(hx)
hx=self.pool4(hx4)
hx5=self.rebnconv5(hx)
hx=self.pool5(hx5)
hx6=self.rebnconv6(hx)
hx7=self.rebnconv7(hx6)
hx6d=self.rebnconv6d(tf.concat((hx7,hx6),axis=3))
hx6d_up=UpSampling2D((2,2),interpolation="bilinear")(hx6d) #上采样
hx5d=self.rebnconv5d(tf.concat((hx6d_up,hx5),axis=3))
hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)
hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
return hx1d+h0
#第二个RSU模块
class RSU6(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU6, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv2=REBNCONV(mid_ch,dirate=1)
self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv3=REBNCONV(mid_ch,dirate=1)
self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv4=REBNCONV(mid_ch,dirate=1)
self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv5=REBNCONV(mid_ch,dirate=1)
self.rebnconv6=REBNCONV(mid_ch,dirate=2)
self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
hx0=self.rebnconv0(inputs)
hx1=self.rebnconv1(hx0)
hx=self.pool1(hx1)
hx2=self.rebnconv2(hx)
hx=self.pool2(hx2)
hx3=self.rebnconv3(hx)
hx=self.pool3(hx3)
hx4=self.rebnconv4(hx)
hx=self.pool4(hx4)
hx5=self.rebnconv5(hx)
hx6=self.rebnconv6(hx5)
hx5d=self.rebnconv5d(tf.concat((hx6,hx5),axis=3))
hx5d_up= UpSampling2D((2,2),interpolation="bilinear")(hx5d)
# print(hx5d_up.shape)
# print(hx4.shape)
hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
return hx1d+hx0
#第三个RSU模块
class RSU5(Model):
def __init__(self, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconv0 = REBNCONV(out_ch, dirate=1)
self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv4 = REBNCONV(mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(out_ch, dirate=1)
def call(self, inputs, training=None, mask=None):
hx0 = self.rebnconv0(inputs)
hx1 = self.rebnconv1(hx0)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(tf.concat((hx5, hx4), axis=3))
hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
return hx1d + hx0
#第四个RSU模块
class RSU4(Model):
def __init__(self, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconv0 = REBNCONV(out_ch, dirate=1)
self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(out_ch, dirate=1)
def call(self, inputs, training=None, mask=None):
hx0 = self.rebnconv0(inputs)
hx1 = self.rebnconv1(hx0)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx)
hx3d = self.rebnconv3d(tf.concat((hx4, hx3), axis=3))
hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
return hx1d + hx0
4. Модуль расширения
После 4-х модулей RSU разрешение карты признаков становится очень маленьким ( при входном разрешении 320 320 здесь разрешение 18 18), если продолжать даунсэмплинг, потеряется больше информации, поэтому автор здесь заменяет конкат и операции повышения дискретизации с расширенной сверткой, то есть карты признаков всех промежуточных слоев в расширенном модуле имеют то же разрешение, что и входные карты признаков. код показывает, как показано ниже:
#扩展模块
class RSU4F(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU4F, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.rebnconv2=REBNCONV(mid_ch,dirate=2)
self.rebnconv3=REBNCONV(mid_ch,dirate=4)
self.rebnconv4=REBNCONV(mid_ch,dirate=8)
self.rebnconv3d=REBNCONV(mid_ch,dirate=4)
self.rebnconv2d=REBNCONV(mid_ch,dirate=2)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
hx0=self.rebnconv0(inputs)
hx1=self.rebnconv1(hx0)
hx2=self.rebnconv2(hx1)
hx3=self.rebnconv3(hx2)
hx4=self.rebnconv4(hx3)
hx3d=self.rebnconv3d(tf.concat((hx4,hx3),axis=3))
hx2d=self.rebnconv2d(tf.concat((hx3d,hx2),axis=3))
hx1d=self.rebnconv1d(tf.concat((hx2d,hx1),axis=3))
return hx1d+hx0
5. Общая структура
Здесь следует отметить, что при получении карты вероятностей каждого слоя используется сигмовидная функция активации, а общий код реализации сети выглядит следующим образом:
#U^2Net
class U2NET(Model):
def __init__(self,out_ch=1):
super(U2NET, self).__init__()
#encode
self.stage1=RSU7(32,64)
self.pool1_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#144*144
self.stage2=RSU6(32,128)
self.pool2_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#72*72
self.stage3=RSU5(64,256)
self.pool3_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#36*36
self.stage4=RSU4(128,512)
self.pool4_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#18*18
self.stage5=RSU4F(256,512)
self.pool5_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#9*9
self.stage6=RSU4F(256,512)
#decode
self.stage5d=RSU4F(256,512)
self.stage4d=RSU4(128,256)
self.stage3d=RSU5(64,128)
self.stage2d=RSU6(32,64)
self.stage1d=RSU7(16,64)
#每个层的输出
self.side1=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side2=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side3=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side4=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side5=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side6=Conv2D(out_ch,kernel_size=(3,3),padding="same")
#最终输出
self.outconv=Conv2D(out_ch,kernel_size=(1,1))
def call(self, inputs, training=None, mask=None):
hx1=self.stage1(inputs)
hx=self.pool1_1(hx1)
hx2=self.stage2(hx)
hx=self.pool2_1(hx2)
hx3=self.stage3(hx)
hx=self.pool3_1(hx3)
hx4=self.stage4(hx)
hx=self.pool4_1(hx4)
hx5=self.stage5(hx)
hx=self.pool5_1(hx5)
hx6=self.stage6(hx)
hx6_up=UpSampling2D((2,2),interpolation="bilinear")(hx6)
#decode
hx5d=self.stage5d(tf.concat((hx6_up,hx5),axis=3))
hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)
hx4d=self.stage4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.stage3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.stage2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.stage1d(tf.concat((hx2d_up,hx1),axis=3))
# side out
d1=self.side1(hx1d)
d2=self.side2(hx2d)
d2=UpSampling2D((2,2),interpolation="bilinear")(d2)
d3=self.side3(hx3d)
d3=UpSampling2D((4,4),interpolation="bilinear")(d3)
d4=self.side4(hx4d)
d4=UpSampling2D((8,8),interpolation="bilinear")(d4)
d5=self.side5(hx5d)
d5=UpSampling2D((16,16),interpolation="bilinear")(d5)
d6=self.side6(hx6)
d6=UpSampling2D((32,32),interpolation="bilinear")(d6)
out=self.outconv(tf.concat((d1,d2,d3,d4,d5,d6),axis=3))
sig=k.activations.sigmoid #定义激活函数
return sig(out),sig(d1),sig(d2),sig(d3),sig(d4),sig(d5),sig(d6)
Выше приведено полное содержание этой статьи.Если вам нужно обучить код или найти ошибки в статье, оставьте сообщение в области комментариев.