序文
U^2Net は以前使ったことのある画像分割ネットワークです。効果が良くて感動しました。学習も兼ねて、このネットワークを TensorFlow2 で再構築しました。忘れたときのためにメモしておきます。関連文書へのリンクは次のとおりです:
紙アドレス公式コードの詳細な解釈
(pytorch)
1. 環境構築
pip install tensorflow-gpu
2. ネットワーク構造
1. ネットワーク構成図
ネットワーク構造図は次のとおりです。
2. ヘッダーファイルをインポートする
import tensorflow.keras as k
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Conv2D,MaxPool2D,BatchNormalization,ReLU,UpSampling2D
3. RSUの構造
RSU デカップリング ユニットは小さな unet 構造に相当し、前半では一連のダウンサンプリングを通じて特徴を抽出し、後半ではアップサンプリングと連結を通じて特徴の融合を実現します。ここで連続する 4 つの RSU モジュールは実際にはもっと簡略化して記述することができることに注意してください。ここではすべてを直接展開して記述します。コードは次のとおりです。
#基本卷积块
class REBNCONV(Model):
def __init__(self,out_ch=3,dirate=1):
super(REBNCONV, self).__init__()
self.conv=Sequential()
self.conv.add(Conv2D(out_ch,kernel_size=(3,3),strides=(1,1),padding="same",dilation_rate=dirate))
self.conv.add(BatchNormalization())
self.conv.add(ReLU())
def call(self, inputs, training=None, mask=None):
x=self.conv(inputs)
return x
#第一个RSU结构
class RSU7(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU7, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv2=REBNCONV(mid_ch,dirate=1)
self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv3=REBNCONV(mid_ch,dirate=1)
self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv4=REBNCONV(mid_ch,dirate=1)
self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv5=REBNCONV(mid_ch,dirate=1)
self.pool5=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv6=REBNCONV(mid_ch,dirate=1)
self.rebnconv7=REBNCONV(mid_ch,dirate=2)
self.rebnconv6d=REBNCONV(mid_ch,dirate=1)
self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
h0=self.rebnconv0(inputs)
hx1=self.rebnconv1(h0)
hx=self.pool1(hx1)
hx2=self.rebnconv2(hx)
hx=self.pool2(hx2)
hx3=self.rebnconv3(hx)
hx=self.pool3(hx3)
hx4=self.rebnconv4(hx)
hx=self.pool4(hx4)
hx5=self.rebnconv5(hx)
hx=self.pool5(hx5)
hx6=self.rebnconv6(hx)
hx7=self.rebnconv7(hx6)
hx6d=self.rebnconv6d(tf.concat((hx7,hx6),axis=3))
hx6d_up=UpSampling2D((2,2),interpolation="bilinear")(hx6d) #上采样
hx5d=self.rebnconv5d(tf.concat((hx6d_up,hx5),axis=3))
hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)
hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
return hx1d+h0
#第二个RSU模块
class RSU6(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU6, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.pool1=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv2=REBNCONV(mid_ch,dirate=1)
self.pool2=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv3=REBNCONV(mid_ch,dirate=1)
self.pool3=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv4=REBNCONV(mid_ch,dirate=1)
self.pool4=MaxPool2D(pool_size=(2,2),strides=(2,2))
self.rebnconv5=REBNCONV(mid_ch,dirate=1)
self.rebnconv6=REBNCONV(mid_ch,dirate=2)
self.rebnconv5d=REBNCONV(mid_ch,dirate=1)
self.rebnconv4d=REBNCONV(mid_ch,dirate=1)
self.rebnconv3d=REBNCONV(mid_ch,dirate=1)
self.rebnconv2d=REBNCONV(mid_ch,dirate=1)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
hx0=self.rebnconv0(inputs)
hx1=self.rebnconv1(hx0)
hx=self.pool1(hx1)
hx2=self.rebnconv2(hx)
hx=self.pool2(hx2)
hx3=self.rebnconv3(hx)
hx=self.pool3(hx3)
hx4=self.rebnconv4(hx)
hx=self.pool4(hx4)
hx5=self.rebnconv5(hx)
hx6=self.rebnconv6(hx5)
hx5d=self.rebnconv5d(tf.concat((hx6,hx5),axis=3))
hx5d_up= UpSampling2D((2,2),interpolation="bilinear")(hx5d)
# print(hx5d_up.shape)
# print(hx4.shape)
hx4d=self.rebnconv4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.rebnconv3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.rebnconv2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.rebnconv1d(tf.concat((hx2d_up,hx1),axis=3))
return hx1d+hx0
#第三个RSU模块
class RSU5(Model):
def __init__(self, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconv0 = REBNCONV(out_ch, dirate=1)
self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv4 = REBNCONV(mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(out_ch, dirate=1)
def call(self, inputs, training=None, mask=None):
hx0 = self.rebnconv0(inputs)
hx1 = self.rebnconv1(hx0)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(tf.concat((hx5, hx4), axis=3))
hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
return hx1d + hx0
#第四个RSU模块
class RSU4(Model):
def __init__(self, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconv0 = REBNCONV(out_ch, dirate=1)
self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(out_ch, dirate=1)
def call(self, inputs, training=None, mask=None):
hx0 = self.rebnconv0(inputs)
hx1 = self.rebnconv1(hx0)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx)
hx3d = self.rebnconv3d(tf.concat((hx4, hx3), axis=3))
hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
return hx1d + hx0
4. 拡張モジュール
4 つの RSU モジュールの後、特徴マップの解像度は非常に小さくなります (入力解像度が 320 320 の場合、ここでの解像度は 18 18 です)。ダウンサンプリングを続けると、より多くの情報が失われるため、著者はここで concat を置き換えます。拡張畳み込みを使用したアップサンプリング操作。つまり、拡張モジュール内のすべての中間層の特徴マップは、入力特徴マップと同じ解像度を持ちます。コードは以下のように表示されます。
#扩展模块
class RSU4F(Model):
def __init__(self,mid_ch=12,out_ch=3):
super(RSU4F, self).__init__()
self.rebnconv0=REBNCONV(out_ch,dirate=1)
self.rebnconv1=REBNCONV(mid_ch,dirate=1)
self.rebnconv2=REBNCONV(mid_ch,dirate=2)
self.rebnconv3=REBNCONV(mid_ch,dirate=4)
self.rebnconv4=REBNCONV(mid_ch,dirate=8)
self.rebnconv3d=REBNCONV(mid_ch,dirate=4)
self.rebnconv2d=REBNCONV(mid_ch,dirate=2)
self.rebnconv1d=REBNCONV(out_ch,dirate=1)
def call(self, inputs, training=None, mask=None):
hx0=self.rebnconv0(inputs)
hx1=self.rebnconv1(hx0)
hx2=self.rebnconv2(hx1)
hx3=self.rebnconv3(hx2)
hx4=self.rebnconv4(hx3)
hx3d=self.rebnconv3d(tf.concat((hx4,hx3),axis=3))
hx2d=self.rebnconv2d(tf.concat((hx3d,hx2),axis=3))
hx1d=self.rebnconv1d(tf.concat((hx2d,hx1),axis=3))
return hx1d+hx0
5. 全体構成
ここで、各層の確率マップを取得する際に使用される活性化関数はシグモイドであり、ネットワーク全体の実装コードは次のとおりであることに注意してください。
#U^2Net
class U2NET(Model):
def __init__(self,out_ch=1):
super(U2NET, self).__init__()
#encode
self.stage1=RSU7(32,64)
self.pool1_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#144*144
self.stage2=RSU6(32,128)
self.pool2_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#72*72
self.stage3=RSU5(64,256)
self.pool3_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#36*36
self.stage4=RSU4(128,512)
self.pool4_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#18*18
self.stage5=RSU4F(256,512)
self.pool5_1=MaxPool2D(pool_size=(2,2),strides=(2,2))#9*9
self.stage6=RSU4F(256,512)
#decode
self.stage5d=RSU4F(256,512)
self.stage4d=RSU4(128,256)
self.stage3d=RSU5(64,128)
self.stage2d=RSU6(32,64)
self.stage1d=RSU7(16,64)
#每个层的输出
self.side1=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side2=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side3=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side4=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side5=Conv2D(out_ch,kernel_size=(3,3),padding="same")
self.side6=Conv2D(out_ch,kernel_size=(3,3),padding="same")
#最终输出
self.outconv=Conv2D(out_ch,kernel_size=(1,1))
def call(self, inputs, training=None, mask=None):
hx1=self.stage1(inputs)
hx=self.pool1_1(hx1)
hx2=self.stage2(hx)
hx=self.pool2_1(hx2)
hx3=self.stage3(hx)
hx=self.pool3_1(hx3)
hx4=self.stage4(hx)
hx=self.pool4_1(hx4)
hx5=self.stage5(hx)
hx=self.pool5_1(hx5)
hx6=self.stage6(hx)
hx6_up=UpSampling2D((2,2),interpolation="bilinear")(hx6)
#decode
hx5d=self.stage5d(tf.concat((hx6_up,hx5),axis=3))
hx5d_up=UpSampling2D((2,2),interpolation="bilinear")(hx5d)
hx4d=self.stage4d(tf.concat((hx5d_up,hx4),axis=3))
hx4d_up=UpSampling2D((2,2),interpolation="bilinear")(hx4d)
hx3d=self.stage3d(tf.concat((hx4d_up,hx3),axis=3))
hx3d_up=UpSampling2D((2,2),interpolation="bilinear")(hx3d)
hx2d=self.stage2d(tf.concat((hx3d_up,hx2),axis=3))
hx2d_up=UpSampling2D((2,2),interpolation="bilinear")(hx2d)
hx1d=self.stage1d(tf.concat((hx2d_up,hx1),axis=3))
# side out
d1=self.side1(hx1d)
d2=self.side2(hx2d)
d2=UpSampling2D((2,2),interpolation="bilinear")(d2)
d3=self.side3(hx3d)
d3=UpSampling2D((4,4),interpolation="bilinear")(d3)
d4=self.side4(hx4d)
d4=UpSampling2D((8,8),interpolation="bilinear")(d4)
d5=self.side5(hx5d)
d5=UpSampling2D((16,16),interpolation="bilinear")(d5)
d6=self.side6(hx6)
d6=UpSampling2D((32,32),interpolation="bilinear")(d6)
out=self.outconv(tf.concat((d1,d2,d3,d4,d5,d6),axis=3))
sig=k.activations.sigmoid #定义激活函数
return sig(out),sig(d1),sig(d2),sig(d3),sig(d4),sig(d5),sig(d6)
以上がこの記事の内容です。コードのトレーニングが必要な場合や記事内の間違いを見つけた場合は、コメント欄にメッセージを残してください。