よりシンプルかつ自然に、より近いです。
参考資料
U-ネット:生物医学画像セグメンテーションのための畳み込みネットワーク
アブストラクト&はじめに
いくつかの論文のキーワードがあります。
- パス収縮パスを収縮。
- 拡張パス膨張路と
- 正確な定位より正確な位置情報。
- オーバーラップタイル境界ミラー反転。
- ランダム弾性変形ランダム弾性変形します。
- 不変のスケール不変性;
- 細胞をタッチすると、近接した二つ細胞をいいます。
- シームレス耕しシームレススプライシング。
まあ、そんなに私たちは、この論文を見てみましょうこれらのキーワードについて、本論文と彼の構造が、同じくらい簡単で理解しやすいです非常に語っています。
まず、主な著者との比較を行うためにスライディングウィンドウ法に基づいて、ネットワークを取る、著者は最初の瞬間をディスこの方法の問題があります:
電子顕微鏡像における深いニューラルネットワークセグメントニューロン膜(NIPS2012)
- 非常に遅く、計算冗長性(私たちはすべてを知っている問題のスライディングウィンドウ)。
- より多くの機能がより最大プーリングを意味し、それはより多くの位置情報を失うことになるので、位置精度や特徴抽出のバランスが、あります。
著者の論文の特徴は、以下に触発された多層のアイデアを入力します。
- オブジェクトセグメンテーションときめ細かい局在化のためのハイパーコラム(2014)
- カスケード接続された階層モデルとロジスティック論理和標準ネットワークによる画像セグメンテーション(2013)
这两篇论文指出把多层特征(the features from multiple layers)输入到classifier能够得到更好的特征提取和更好的位置信息(good localization and the use of context are possible at the same time)。
U-Net和其他网络的不同之处在于,上采样(Upsampling)过程中也有很多维特征,让特征流向更高分辨率的卷积层。
由于网络使用的卷积是3x3 unpadded convolutions,所以特征图会缩小,为了让输出的图像和输入图像的大小无缝拼接(seamless tilling),则要用到边界镜像翻转(overlap-tile),具体做法如下图:
Architecture
网络结构
使用3x3 unpadded convolutions,所以特征图会不断缩小,在横向拼接特征的时候,也要对特征图进行裁剪,以保持特征图大小一致。
全部使用ReLU激活函数。
权值初始化使用何恺明的方法:
Surpassing humanlevel performance on imagenet classification
具体做法就是一个标准差满足sqrt(2/N)
的高斯分布,其中的N代表一个神经元的输入节点数(例如一个3x3卷积核的输入是64维的话,那么N=9x64=576)
训练
在训练时作者更倾向于更大的图像输入,所以干脆将batch_size设置为1,所以在优化器的使用方面,使用到了带有动量的优化器,并且动量设置的很大(0.99),这么做是为了让以前的样本可以决定当前梯度更新的方向(因为batch_size太小啦,可以理解)。
损失函数就是pixel-wise soft-max + cross_entropy了。
数据增强
随机弹性形变和weight map:
ランダム弾性変形が10pixelランダム初期変位ベクトルのガウス分布の標準偏差から、ランダム変形3x3の粗いグリッドを初期化するために使用され、次いで、バイキュービック双三次補間は、各画素の変位を算出します。
目的は、ランダム弾性変形ネットワークは不変性(スケール不変性)を有することを可能にすることです。
以下に示すように、ウェイトマップは、これらの細胞は、重量損失関数に高い重みをタッチバックグラウンドとの間に位置し、接触細胞との間のバックグラウンドを学習するためにネットワークを強制することです。
次のように詳細なウェイトマップが計算されます。
コード
:最後に、コードそれを見https://github.com/milesial/Pytorch-UNet
全体のモデル:
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.inc = inconv(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return F.sigmoid(x)
詳細:
class double_conv(nn.Module):
'''(conv => BN => ReLU) * 2'''
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
self.conv = double_conv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
self.mpconv = nn.Sequential(
nn.MaxPool2d(2),
double_conv(in_ch, out_ch)
)
def forward(self, x):
x = self.mpconv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(up, self).__init__()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
トレーニング:
optimizer = optim.SGD(net.parameters(),
lr=lr,
momentum=0.9,
weight_decay=0.0005)
criterion = nn.BCELoss()