## 1 介绍

U-Net是一篇基本结构非常好的论文，主要针对生物医学图像的分割；在此后的许多医学图像分割网络中，很大一部分都采用U-Net作为网络主干。相对于当年的其他方法，它在EM segmentation challenge at ISBI 2012上取得了比当时最佳结果更好的成绩，而且速度也非常快。它还有一个很好的优点：在小数据集上也能取得比较好的效果。例如EM 2012这个数据集就只有30张果蝇一龄幼虫腹侧神经索连续切片的透射电子显微镜图像。

## 2 源代码

（1）网络结构代码

```python
import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    """(Conv2d => BatchNorm2d => ReLU) * 2 — the basic U-Net building block.

    The pasted version lost the two Conv2d layers: `BatchNorm2d(out_ch)` was
    applied directly to an `in_ch`-channel input, which fails whenever
    in_ch != out_ch and never changes the channel count. Restored the
    standard conv-bn-relu pairs (3x3 kernels, padding=1 so the spatial size
    is preserved — required for the skip concatenations in Unet).
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, input):
        # Apply both conv-bn-relu stages in sequence.
        return self.conv(input)

class Unet(nn.Module):
    """Classic U-Net: a 4-stage encoder, a bottleneck, a 4-stage decoder
    with skip connections, and a final 1x1 convolution + sigmoid.

    Assumes DoubleConv preserves spatial size (3x3 convs with padding=1),
    so input H and W should be divisible by 16 for the skip concatenations
    to line up.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder: feature channels double at each stage.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck.
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each transpose conv halves the channels and doubles H/W;
        # the following DoubleConv consumes the skip + upsample concatenation.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv maps the 64 features to the requested output channels.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Encoder path — keep each stage's output for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder path — concatenate skip features along the channel axis.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # torch.sigmoid avoids constructing a fresh nn.Sigmoid every forward.
        return torch.sigmoid(c10)
```

（2）数据集准备

```python
import torch.utils.data as data
import PIL.Image as Image
import os

def make_dataset(rootdata, roottarget):
    """Return a list of (image_path, mask_path) pairs for the dataset.

    NOTE(review): the pasted snippet lost the `def` line and the pairing
    logic (it had a module-level `return` and discarded `img`). Reconstructed
    from its callers: MyDataset.__getitem__ unpacks each entry as
    (x_path, y_path), and the training script calls
    make_dataset(data_dir, gt_dir). Masks are assumed to share the image
    file name under roottarget — TODO confirm the naming scheme.
    """
    imgs = []
    for name in os.listdir(rootdata):
        img = os.path.join(rootdata, name)
        mask = os.path.join(roottarget, name)
        imgs.append((img, mask))
    return imgs

class MyDataset(data.Dataset):
def __init__(self, rootdata, roottarget, transform=None, target_transform=None):
imgs = make_dataset(rootdata,roottarget)
#print(imgs)
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform

def __getitem__(self, index):
x_path, y_path = self.imgs[index]
# print(x_path)
img_x = Image.open(x_path).convert('L')#读取并转换为二值图像
img_y = Image.open(y_path).convert('L')
if self.transform is not None:
img_x = self.transform(img_x)
if self.target_transform is not None:
img_y = self.target_transform(img_y)
# print(img_x.shape[0])
# print(img_x.shape[1])
# print(img_x.shape[2])
# print(img_x)
return img_x, img_y

def __len__(self):
return len(self.imgs)``````

（3）训练与测试

```python
import numpy as np
import torch
import argparse
from torchvision import transforms
from unet import Unet
from dataset import MyDataset
from dataset import make_dataset
import os
import cv2
import time

# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inputs are opened with .convert('L') (single channel) and the model is
# Unet(1, 1), so Normalize must use 1-channel statistics — the original
# 3-channel [0.5, 0.5, 0.5] would raise on a 1-channel tensor.
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Masks are only converted to tensors (values stay in [0, 1] for BCELoss).
y_transforms = transforms.ToTensor()

def train_model(model, criterion, optimizer, dataload, num_epochs=150):
    """Run the training loop and checkpoint the model after every epoch.

    NOTE(review): the pasted snippet lost the inner `for x, y in dataload:`
    loop (the batch body was dedented to epoch level) and `dt_size` was
    undefined; both restored. Also added the missing optimizer.zero_grad()
    — without it gradients accumulate across batches — and stopped calling
    model.cpu() before saving, which silently moved the model off the GPU
    mid-training.

    Returns the trained model.
    """
    dt_size = len(dataload.dataset)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()  # clear gradients from the previous batch
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.6f" %
                  (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.6f" % (epoch, epoch_loss))
        # Save both the state dict and the full pickled module each epoch.
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
        torch.save(model, 'weights_%d_dc.pth' % epoch)
    return model

#训练模型
def train(train_data_path, train_gt_path):
    """Build the training DataLoader and run the training loop.

    NOTE(review): the original body was truncated mid-statement (only the
    DataLoader keyword arguments survived); the DataLoader construction and
    the hand-off to train_model were reconstructed. `model`, `criterion`
    and `optimizer` are presumed to be the module-level objects set up in
    the __main__ block — verify against the full script.
    """
    batch_size = 1
    liver_dataset = MyDataset(
        train_data_path, train_gt_path,
        transform=x_transforms, target_transform=y_transforms)
    dataloaders = torch.utils.data.DataLoader(
        liver_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    train_model(model, criterion, optimizer, dataloaders)

#显示模型的输出结果
def test(test_data_path,test_gt_path,save_pre_path):
# liver_dataset = MyDataset(
#     "image/val/data", "image/val/gt", transform=x_transforms, target_transform=y_transforms)
liver_dataset = MyDataset(
test_data_path, test_gt_path, transform=x_transforms, target_transform=y_transforms)
imgs = make_dataset(test_data_path, test_gt_path)
# print(imgs[0][1])
print(liver_dataset)
import matplotlib.pyplot as plt
plt.ion()
count = 0
start = time.clock()
x = x.to('cuda')
y = model(x)
img_y = torch.squeeze(y).cpu().numpy()
elapsed = (time.clock() - start)
print("Time used:",elapsed)
plt.imsave(os.path.join(save_pre_path,os.path.basename(imgs[count][1])), img_y)
count+=1
plt.show()

def test_forDirsImages(source_data_path,source_gt_path,save_path):
if not os.path.exists(source_data_path):
return
if not os.path.exists(source_gt_path):
return
if not os.path.exists(save_path):
os.makedirs(save_path)
sour_data_path = source_data_path
sour_gt_path = source_gt_path
sav_path = save_path
for i in range(1,100):
source_data_path = os.path.join(source_data_path,str(i))
source_gt_path = os.path.join(source_gt_path,str(i))
save_path = os.path.join(save_path, str(i))
if not os.path.exists(save_path):
os.makedirs(save_path)
test(source_data_path, source_gt_path, save_path)

source_data_path = sour_data_path
source_gt_path = sour_gt_path
save_path = sav_path

if __name__ == '__main__':
pretrained = False
model = Unet(1, 1).to(device)
if pretrained:
criterion = torch.nn.BCELoss()