Unetはすでに非常に古いセグメンテーションモデルです。これは、2015年の「U-Net:生物医学画像セグメンテーションのための畳み込みネットワーク」で提案されたモデルです。
紙のリンク:https://arxiv.org/abs/1505.04597
Unet以前は古いFCNネットワークでした。FCNはFullyConvolutionalNetowkrsの略で、ネットワークを分割するための基本的なフレームワークを確立しますが、FCNネットワークの精度は低く、Unetほど使いやすいものではありません。 。
Unetネットワークは非常に単純で、前半は特徴抽出で、後半はアップサンプリングです。一部の文献では、この構造はエンコーダ-デコーダ構造と呼ばれています。ネットワークの全体的な構造はより大きな英字Uであるため、U-netと呼ばれます。
ネットワーク構造は次のとおりです。
- エンコーダー:左半分は、2つの3x3畳み込み層(RELU)と2x2の最大プーリング層で構成され、ダウンサンプリングモジュールを形成します(後のコードで確認できます)。
- デコーダー:半分の部分があり、アップサンプリング畳み込み層(デコンボリューション層)+フィーチャーステッチング畳み込み+ 2つの3x3畳み込み層(ReLU)で構成されています(コードで確認できるように)。この種のパススルーチャネル数スプライシングを使用すると、より多くの機能を取得できますが、より多くのメモリを消費します。
UNetの構造は、低レベルの機能と高レベルの機能の情報を組み合わせることができるように設計されています。
低レベル(ディープ)機能:複数のダウンサンプリング後の低解像度情報。これは、画像全体のセグメンテーションターゲットのコンテキストセマンティック情報を提供できます。これは、ターゲットとその環境との関係を反映する機能として理解できます。この機能は、オブジェクトカテゴリの判断に役立ちます(したがって、分類の問題は通常、低解像度/詳細な情報のみを必要とし、マルチスケールの融合は含まれません)
高レベル(浅い)機能:連結操作を介して、エンコーダーから同じ高さのデコーダーに直接渡される高解像度の情報。グラデーションなど、セグメンテーションのためのより洗練された機能を提供できます。
サイズの非互換性の理由について:
写真から、左右の寸法が正しくないことがわかりますので、合わせたい場合はトリミングが必要です。灰色の矢印はすべてコピーアンドクロップですが、再現されたものはありません。モデルも用意されています。このように、左右のサイズを同じに設定し、各畳み込みにパディングを追加して、畳み込み後にサイズが変わらないようにします。
短所:
-
ネットワークの動作が非常に遅い。ネットワークは近隣ごとに1回実行され、重複する近隣に対して操作を繰り返します。
-
ネットワークは、正確なローカリゼーションとコンテキスト情報の取得の間でトレードオフを行う必要があります。パッチが大きいほど、より多くの最大プーリングレイヤーが必要になり、ローカリゼーションの精度が低下します。一方、近隣が小さいと、ネットワークが取得するコンテキスト情報が少なくなります。
UNetコード(pytorch)
import torch.nn as nn
import torch
from torch import autograd
#把常用的2个卷积操作简单封装下
class DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch), #添加了BN层
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
class Unet(nn.Module):
def __init__(self, in_ch, out_ch):
super(Unet, self).__init__()
self.conv1 = DoubleConv(in_ch, 64)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = DoubleConv(64, 128)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = DoubleConv(128, 256)
self.pool3 = nn.MaxPool2d(2)
self.conv4 = DoubleConv(256, 512)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = DoubleConv(512, 1024)
# 逆卷积,也可以使用上采样(保证k=stride,stride即上采样倍数)
self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.conv6 = DoubleConv(1024, 512)
self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.conv7 = DoubleConv(512, 256)
self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.conv8 = DoubleConv(256, 128)
self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.conv9 = DoubleConv(128, 64)
self.conv10 = nn.Conv2d(64, out_ch, 1)
def forward(self, x):
c1 = self.conv1(x)
p1 = self.pool1(c1)
c2 = self.conv2(p1)
p2 = self.pool2(c2)
c3 = self.conv3(p2)
p3 = self.pool3(c3)
c4 = self.conv4(p3)
p4 = self.pool4(c4)
c5 = self.conv5(p4)
up_6 = self.up6(c5)
merge6 = torch.cat([up_6, c4], dim=1)
c6 = self.conv6(merge6)
up_7 = self.up7(c6)
merge7 = torch.cat([up_7, c3], dim=1)
c7 = self.conv7(merge7)
up_8 = self.up8(c7)
merge8 = torch.cat([up_8, c2], dim=1)
c8 = self.conv8(merge8)
up_9 = self.up9(c8)
merge9 = torch.cat([up_9, c1], dim=1)
c9 = self.conv9(merge9)
c10 = self.conv10(c9)
out = nn.Sigmoid()(c10)
return out
Unetコード(Keras)
def unet(pretrained_weights=None, input_size=(256, 256, 3)):
inputs = Input(input_size)
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
drop4 = Dropout(0.5)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
drop5 = Dropout(0.5)(conv5)
up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
UpSampling2D(size=(2, 2))(drop5))
merge6 = concatenate([drop4, up6], axis=3)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
UpSampling2D(size=(2, 2))(conv6))
merge7 = concatenate([conv3, up7], axis=3)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
UpSampling2D(size=(2, 2))(conv7))
merge8 = concatenate([conv2, up8], axis=3)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
UpSampling2D(size=(2, 2))(conv8))
merge9 = concatenate([conv1, up9], axis=3)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)
model = Model(inputs=inputs, outputs=conv10)
model.summary()
if (pretrained_weights):
model.load_weights(pretrained_weights)
return model