UNet代码详解
第一步，照例先导入需要用到的库：
import torch
import torch.nn as nn
import torch.nn.functional as F
创建一个卷积Block类
class UNetConvBlock(nn.Module):
    """Convolution block used at every stage of the U-Net.

    The block applies two 3x3 convolutions in sequence. Each convolution is
    followed by a ReLU and, when ``batch_norm`` is truthy, a ``BatchNorm2d``.

    Args:
        in_chans: Number of channels of the input tensor.
        out_chans: Number of channels produced by both convolutions.
        padding: Converted with ``int(padding)`` and passed as the ``padding``
            argument of each convolution. ``True`` (1) keeps the spatial size;
            ``False`` (0) makes each 3x3 convolution shrink height and width
            by 2.
        batch_norm: If truthy, add ``nn.BatchNorm2d(out_chans)`` after each
            ReLU.
    """

    def __init__(self, in_chans, out_chans, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        # First conv maps in_chans -> out_chans, second keeps out_chans.
        for conv_in in (in_chans, out_chans):
            block.append(
                nn.Conv2d(conv_in, out_chans, kernel_size=3, padding=int(padding))
            )
            block.append(nn.ReLU())
            if batch_norm:
                block.append(nn.BatchNorm2d(out_chans))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        """Run ``x`` through the conv/ReLU(/BatchNorm) stack and return the result."""
        out = self.block(x)
        return out
这里实现的是每一个 stage 的卷积 block：连续两次 3×3 卷积，每次卷积后接 ReLU（以及可选的 BatchNorm）。如下图：
创建上采样的Block
class UNetUpBlock(nn.Module):
    """Decoder stage of the U-Net: upsample, then apply a conv block.

    Args:
        in_chans: Number of channels of the tensor to be upsampled.
        out_chans: Number of channels produced by the upsampling layer and by
            the conv block.
        up_mode: ``'upconv'`` for a 2x2 stride-2 transposed convolution, or
            ``'upsample'`` for bilinear upsampling by a factor of 2 followed
            by a 1x1 convolution.
        padding: Forwarded to ``UNetConvBlock``.
        batch_norm: Forwarded to ``UNetConvBlock``.

    Raises:
        ValueError: If ``up_mode`` is neither ``'upconv'`` nor ``'upsample'``.
    """

    def __init__(self, in_chans, out_chans, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_chans, out_chans, kernel_size=1),
            )
        else:
            # Fail here rather than later in forward() with a missing self.up.
            raise ValueError(
                "up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode)
            )

        # The conv block receives the upsampled tensor concatenated with the
        # skip-connection tensor, hence in_chans (not out_chans) inputs.
        # NOTE(review): this assumes the skip tensor carries
        # in_chans - out_chans channels, as UNet wires it; confirm for other callers.
        self.conv_block = UNetConvBlock(in_chans, out_chans, padding, batch_norm)
上采样有两种方式：转置卷积（up_mode='upconv'）和双线性插值（up_mode='upsample'）。这里可以通过 up_mode 参数选择使用哪种方式实现。
def centre_crop(self, layer, target_size):
    """Return the central ``target_size`` window of ``layer``.

    Method of ``UNetUpBlock``; it does not read any instance state.

    Args:
        layer: 4-D tensor whose last two dimensions are height and width.
        target_size: Indexable pair ``(height, width)`` of the desired
            output size.

    Returns:
        A slice of ``layer`` with the first two dimensions untouched and the
        last two cropped to ``target_size``. When the size difference is odd,
        the extra row/column is removed from the bottom/right side.
    """
    _, _, layer_height, layer_width = layer.size()
    # Offset of the crop window from the top-left corner.
    diff_y = (layer_height - target_size[0]) // 2
    diff_x = (layer_width - target_size[1]) // 2
    return layer[
        :,
        :,
        diff_y:(diff_y + target_size[0]),
        diff_x:(diff_x + target_size[1]),
    ]
这里实现的是裁剪操作。注意到 skip connection 两边的特征图大小是不一样的，根据论文描述，我们需要将 Encoder 部分的特征图从中心裁剪到与 Decoder 部分相同的大小，如图所示：
def forward(self, x, bridge):
    """Upsample ``x``, merge it with the skip connection, and convolve.

    Method of ``UNetUpBlock``.

    Args:
        x: Tensor coming from the previous (deeper) stage.
        bridge: Skip-connection tensor from the matching encoder stage.

    Returns:
        Output of ``self.conv_block`` applied to the channel-wise
        concatenation of the upsampled ``x`` and the cropped ``bridge``.
    """
    up = self.up(x)
    # Crop the skip tensor to the spatial size of the upsampled tensor so
    # the two can be concatenated.
    crop1 = self.centre_crop(bridge, up.shape[2:])
    out = torch.cat([up, crop1], 1)  # concatenate along dim 1
    out = self.conv_block(out)
    return out
创建UNet
class UNet(nn.Module):
    """U-Net built from ``UNetConvBlock`` and ``UNetUpBlock`` stages.

    Args:
        in_channels: Number of channels of the input tensor.
        n_classes: Number of output channels of the final 1x1 convolution.
        depth: Number of conv blocks in the downsampling path; the
            upsampling path has ``depth - 1`` blocks.
        wf: Stage ``i`` of the downsampling path has ``2 ** (wf + i)``
            output channels, so the first stage has ``2 ** wf``.
        padding: Forwarded to every conv block.
        batch_norm: Forwarded to every conv block.
        up_mode: ``'upconv'`` or ``'upsample'``; forwarded to every
            ``UNetUpBlock``.
    """

    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        up_mode='upconv'
    ):
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels

        # Encoder: the channel count doubles at every stage.
        self.down_path = nn.ModuleList()
        for i in range(depth):  # 0 1 2 3 4 for the default depth
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)  # channels produced by this stage

        # Decoder: mirror the encoder stages, skipping the deepest one.
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # 1x1 convolution mapping the last feature map to n_classes channels.
        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns:
            The output of the final 1x1 convolution.
        """
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            # Every stage except the deepest one feeds a skip connection
            # and is followed by 2x2 max pooling.
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        # Skip tensors are consumed in reverse order: deepest first.
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)