U-Net网络结构解析和代码解析

U-Net网络结构详解

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN(Fully Convolutional Networks for Semantic Segmentation),而U-Net是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。U-Net网络在医疗影像领域的应用十分广泛,成为了大多数医疗影像语义分割任务的baseline,同时基于U-Net网络改进网络也纷纷出现,本篇文章主要介绍U-NET网络。

由于医学图像往往包含噪声且边界模糊,仅靠低层次的图像特征难以进行目标检测。同时,由于缺乏图像的细节信息,仅靠图像语义特征无法得到准确的边界。而U-Net通过跳跃连接,将低分辨率和高分辨率的特征映射结合起来,有效地融合了低层次和高层次的图像特征,从而成为医学图像分割任务的一个理想解决方案。目前,U-Net已经成为了大多数医学图像分割任务的一个基准,并且激发了很多有意义的改进方法,其网络结构下图所示。

![][pt_01]

U-Net是一个全卷积神经网络,输入与输出都是图像,没有全连接层;并且由图可知,U-Net在宏观上是一个对称的网络结构,左侧为下采样,右侧为上采样,同时按照功能可以将左侧的一系列下采样操作称为encoder,将右侧的一系列上采样操作称为decoder,因此U-Net网络可以划分到Encoder-decoder基础模型类型中;该网络最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

Skip Connection跳层连接中间四条灰色的箭头copy and crop,Skip Connection是在上采样的过程中,融合下采样过过程中的feature map。

Skip Connection跳层连接用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。例如,一个大小为256×256×64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64;和一个大小为256×256×32的feature map进行Concat融合,就会得到一个大小为256×256×96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256×256×64的feature map和240×240×32的feature map进行Concat。解决这个问题有两种办法:

  • 第一种:将大256×256×64的feature map进行裁剪,裁剪为240×240×64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240×240×96的feature map。

  • 第二种:将小240×240×32的feature map进行padding操作,padding为256×256×32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256×256×96的feature map。

U-Net网络核心思想:

  • 不含全连接层(fc)的全卷积(fully conv)网络。可适应任意尺寸输入。
  • 增大数据尺寸的反卷积(deconv)层。能够输出精细的结果。
  • 结合不同深度层结果的跳级(skip)结构。同时确保鲁棒性和精确性。

这里使用1×1的卷积替代全连接层还有一个好处:输入的图片形状不再固定了。由于全连接层的输入必须固定形状的,所以输入模型的图片一般都要先resize到固定的shape,而使用1×1卷积代替全连接层之后变不在存在这一问题。在推理的时候,不需要再对图片进行resize,从而最好可能会导致输出的图片的失真。

扫描二维码关注公众号,回复: 15101450 查看本文章

这么一个不断加深网络并不断增加通道数来提取浅层信息和深层特征的过程就是编码器 (Encoder)

U-Net未能解决的一些问题:

  • 组织器官的顶层截面和底层截面与中部截面差异过大而不易识别;
  • 不同扫描影像之间有较大的外观变异而不易识别;
  • 磁场不均匀引起的伪影和畸变,导致不易识别。

U-Net网络架构实现代码解析

将U-Net网络中的架构分解为四个模块:

  1. 输入层的DoubleConv模块;
  2. 左侧分支从第二层开始的max_pool+DoubleConv,称为Down模块;
  3. 右侧分支的up_conv+copy_crop+DoubleConv,称为Up模块;
  4. 输出层的1x1卷积,称为OutConv模块。

在这里插入图片描述

从上图可以看到,Unet网络的结构比较简单,左侧分支每一层包含两个重复的卷积,命名为DoubleConv;从第二层开始,都是max pool + DoubleConv;右侧分支每一层都是up conv + copy crop + DoubleConv;在最后输出层,有一个1x1 conv。

1. 模块实现

1.1 DoubleConv模块

DoubleConv模块由两个“Conv2d+NatchNorm2d+ReLU”组成:

# unet_parts.py
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
 
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
 
    def forward(self, x):
        return self.double_conv(x)

1.2 Down模块

Down模块由一个“MaxPool2d+DoubleConv”组成:

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
 
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
 
    def forward(self, x):
        return self.maxpool_conv(x)

1.3 Up模块

右侧上行模块涉及到copy and crop,实现起来会略微复杂一些。首先经过一个上采样或转置卷积,然后从左侧路径的同一层feature map中截取相同的size(从图中很容易可以看出,左侧同一层中的feature map比右侧的size要大一些),与右侧feature map合并,最后再进行DoubleConv。代码如下:

class Up(nn.Module):
    """Upscaling then double conv"""
 
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
 
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
 
        self.conv = DoubleConv(in_channels, out_channels)
 
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
 
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
 
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
  • Upsample通过插值方法完成上采样,不需要训练参数。
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

其中mode为选择的插值方法,双线性插值如下图所示:

![][pt_04]

已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

  • ConvTranspose2d可以理解为卷积的逆过程,可以训练参数。
nn.ConvTranspose2d(mid_channels, mid_channels, kernel_size=4, stride=2, padding=1)

其中输出尺寸与输入关系如下,所以,k=4, s=2, p=1即2倍上采样。

o u t p u t = ( i n p u t − 1 ) ∗ s t r i d e + o u t p u t p a d d i n g − 2 ∗ p a d d i n g + k e r n e l s i z e output = (input-1)*stride+outputpadding-2*padding+kernelsize output=(input1)stride+outputpadding2padding+kernelsize

具体执行过程为通过对原图插值0,扩大尺寸,然后改变卷积参数,对扩大尺寸后的进行卷积即nn.ConvTranspose2d:

  1. 原图插值,在两两元素之间插0;
  2. 改变参数。新的卷积核:Stride′=1, kernel的size′ = size padding’ 为Size−padding−1;
  3. 卷积。

![][pt_03]

1.4 OutConv模块

输出层中用1x1卷积实现:

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
 
    def forward(self, x):
        return self.conv(x)

2. 整体架构

# unet_model.py
import torch.nn.functional as F
from .unet_parts import *

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
 
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
 
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
 
if __name__ == '__main__':
    net = UNet(n_channels=3, n_classes=1)
    print(net)

U-Net网络案例:分割细胞边缘

1. 数据加载

基于模板进行数据加载:  
# ================================================================== #
#                Input pipeline for custom dataset                 #
# ================================================================== #
 
# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 
 
# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

使用上面的标准模板,进行加载数据、定义标签、数据增强等操作。

# dataset.py
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
 
class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # 初始化函数,读取所有data_path下的图片
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
 
    def augment(self, image, flipCode):
        # 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # 根据index读取图片
        image_path = self.imgs_path[index]
        # 根据image_path生成label_path
        label_path = image_path.replace('image', 'label')
        # 读取训练图片和标签图片
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # 将数据转为单通道的图片
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # 处理标签,将像素值为255的改为1
        if label.max() > 1:
            label = label / 255
        # 随机进行数据增强,为2时不做处理
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        return image, label
 
    def __len__(self):
        # 返回训练集大小
        return len(self.imgs_path)
 
    
if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("数据个数:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2, 
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

__init__函数是类的初始化函数,根据指定的图片路径data_path,读取所有图片数据,存放到self.imgs_path列表中。

__len__函数返回数据的数量,这个类实例化后,通过len()函数调用。

__getitem__函数是数据获取函数,函数里定义数据读取方式,处理方式,同时数据预处理、数据增强等也在该函数中进行定义。该案例中首先读取图片,并将其处理成单通道图片;将 label 的图片像素点的范围从[0, 255]归一化为[0, 1];最后随机进行了数据增强。

augment函数是定义数据增强函数,案例中进行的是旋转操作,除此之外还可以定义多种其他数据增强函数。

在这个类中,不用进行一些打乱数据集的操作,也不用管怎么按照 batchsize 读取数据。实例化这个类后,用 torch.utils.data.DataLoader 方法指定 batchsize 的大小,决定是否打乱数据。

2. 模型选择

这里使用的第二部分的U-Net网络结构,不再详细解释。

3. 算法选择

Loss函数的选择会对算法拟合数据效果产生较大的影响,分割细胞边缘是一个简单的二分类任务,所以选择使用BCEWithLogitsLoss。BCEWithLogitsLoss 是 Pytorch 提供的用来计算二分类交叉熵的函数:

l o s s ( o , t ) = − 1 / n ∑ i ( t [ i ] ∗ l o g ( o [ i ] ) + ( 1 − t [ i ] ) ∗ l o g ( 1 − o [ i ] ) ) loss(o, t) = -1/n\sum_i(t[i]*log(o[i]) + (1-t[i])*log(1-o[i])) loss(o,t)=1/ni(t[i]log(o[i])+(1t[i])log(1o[i]))

这是 Logistic 回归的损失函数,该函数利用 Sigmoid 函数阈值在[0,1]这个特性来进行分类。

优化目标的算法选择了一种基于AdaGrad改进的自适应的优化算法:RMSProp。优化算法本质上是梯度下降,例如:SGD(随机梯度下降算法)、Momentum(引入了动量的 SGD,以指数衰减的形式累计历史梯度)。而自适应参数的优化算法最大的特点是每个参数有不同的学习率,在整个学习过程中自动适应这些学习率,从而达到更好的收敛效果。

# train.py
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
 
def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # 加载训练集
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size, 
                                               shuffle=True)
    # 定义RMSprop算法
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # 定义Loss算法
    criterion = nn.BCEWithLogitsLoss()
    # best_loss统计,初始化为正无穷
    best_loss = float('inf')
    # 训练epochs次
    for epoch in range(epochs):
        # 训练模式
        net.train()
        # 按照batch_size开始训练
        for image, label in train_loader:
            optimizer.zero_grad()
            # 将数据拷贝到device中
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # 使用网络参数,输出预测结果
            pred = net(image)
            # 计算loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # 保存loss值最小的网络参数
            if loss < best_loss:
                best_loss = loss
                torch.save(net.state_dict(), 'best_model.pth')
            # 更新参数
            loss.backward()
            optimizer.step()
 
if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道1,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 指定训练集地址,开始训练
    data_path = "data/train/"
    train_net(net, device, data_path)

4. 预测

import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
 
if __name__ == "__main__":
    # 选择设备,有cuda用cuda,没有就用cpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 加载网络,图片单通道,分类为1。
    net = UNet(n_channels=1, n_classes=1)
    # 将网络拷贝到deivce中
    net.to(device=device)
    # 加载模型参数
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # 测试模式
    net.eval()
    # 读取所有图片路径
    tests_path = glob.glob('data/test/*.png')
    # 遍历所有图片
    for test_path in tests_path:
        # 保存结果地址
        save_res_path = test_path.split('.')[0] + '_res.png'
        # 读取图片
        img = cv2.imread(test_path)
        # 转为灰度图
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # 转为batch为1,通道为1,大小为512*512的数组
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # 转为tensor
        img_tensor = torch.from_numpy(img)
        # 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # 预测
        pred = net(img_tensor)
        # 提取结果
        pred = np.array(pred.data.cpu()[0])[0]
        # 处理结果
        pred[pred >= 0.5] = 255
        pred[pred < 0.5] = 0
        # 保存图片
        cv2.imwrite(save_res_path, pred)

输出结果:

![][pt_05]

整体目录:

├── data
│ ├── test
│ │ ├──……
│ │ └── *.png
│ └── train
│ ├──……
│ └── *.png
├── model
│ ├── unet_model.py
│ └── unet_parts.py
├── utils
│ └── dataset.py
├── train.py
└── predict.py

猜你喜欢

转载自blog.csdn.net/weixin_42992706/article/details/129251640