UNet for image segmentation on the DRIVE dataset

Table of contents

1. Introduction

2. Build the UNet network

3. dataset: data loading

4. train: training the network

5. predict: segmenting an image

6. Results

7. Complete code


1. Introduction

The project directory is laid out as follows:

  1. DRIVE holds the dataset
  2. predict holds the images to be segmented
  3. result holds the segmentation output for the images in predict
  4. dataset is the data-processing script, model contains the UNet network, predict runs the prediction, train trains the network, and UNet.pth is the trained weight file

 

I covered an image segmentation example in earlier posts, and most of that code carries over to this article, so each script is only introduced briefly here. For details, refer to the earlier posts:

model: UNet - unet network

dataset: UNet - data loading Dataset

train: UNet - training data train

predict: UNet - prediction data predict (segmentation of multiple images)

DRIVE (Digital Retinal Images for Vessel Extraction) is a dataset of retinal images for blood-vessel segmentation.

Training samples: grayscale images

Corresponding labels: binary images

This segmentation project was finished a few weeks ago and only written up recently, so the original DRIVE dataset may actually consist of color images plus masks (I can't remember exactly)

  • The masks are not used here
  • If the images are in color, set the number of input channels to 3 when creating the UNet. If you would rather train on grayscale images, either convert them beforehand with OpenCV (the grayscale result looks similar to the samples shown), or convert them during preprocessing with transforms.Grayscale()

2. Build the UNet network

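Unlike the earlier UNet, this version pads the feature maps, so input images of any size are supported

The earlier code downsamples 4 times, halving the size each time while expanding the channels, so the height and width of the input image had to be multiples of 2 to the 4th power (16)

 

I do not fully understand the implementation of this part, but in testing it does accept any input size. The relevant code is in Up.forward: the upsampled feature map is padded with F.pad until it matches the height and width of the feature map from the skip connection, so an odd size that was rounded down by max pooling no longer breaks the concatenation.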
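To see why the padding is enough, one spatial dimension can be traced through the network by hand. The sketch below is plain Python arithmetic, not part of the project code; the function name trace_unet_sizes is made up for this illustration, and it assumes the layers used in the model file (MaxPool2d with kernel 2 and stride 2, ConvTranspose2d with kernel 2 and stride 2).

```python
def trace_unet_sizes(size, depth=4):
    """Trace one spatial dimension (height or width) through the UNet.

    Returns one (upsampled, skip, pad_before, pad_after) tuple per Up block,
    in the order the decoder applies them.
    """
    # Encoder: each MaxPool2d(kernel_size=2, stride=2) halves the size,
    # rounding odd sizes down.
    encoder = [size]
    for _ in range(depth):
        encoder.append(encoder[-1] // 2)

    steps = []
    current = encoder[-1]
    for skip in reversed(encoder[:-1]):
        upsampled = current * 2      # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
        diff = skip - upsampled      # 1 if the skip size was odd, otherwise 0
        steps.append((upsampled, skip, diff // 2, diff - diff // 2))
        current = skip               # after padding, the sizes match again
    return steps


# The two dimensions of a 565*584 image
for name, size in (("584", 584), ("565", 565)):
    print(name, trace_unet_sizes(size))
```

Whenever a skip-connection size is odd, the upsampled map is one pixel short, and F.pad adds that pixel back before torch.cat. With sizes that are multiples of 16 the difference is always 0, which is why the earlier version worked without padding.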

3. dataset: data loading

The image preprocessing is also defined here, alongside the data loading

The training images go through ToTensor, which converts them to tensors, scales the pixel values to [0, 1], and reorders the dimensions to channel-first. To speed up training, the images are then normalized; because the training images are grayscale, only a single-channel mean and standard deviation are needed
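The effect of the two preprocessing steps on a single 8-bit pixel can be written out by hand. This is only a sketch of the arithmetic, assuming the mean of 0.5 and standard deviation of 0.5 used in the dataset file; the helper names are made up for this example and are not torchvision functions.

```python
def to_unit_range(pixel):
    """What ToTensor does to an 8-bit gray value: scale 0..255 to 0..1."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """What Normalize((0.5,), (0.5,)) does to one value: (x - mean) / std."""
    return (value - mean) / std


for pixel in (0, 128, 255):
    print(pixel, normalize(to_unit_range(pixel)))
```

With these settings the darkest pixel maps to -1 and the brightest to 1, so the network sees inputs centered around zero.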


Next comes the initialization of the data loader

imgs holds the file names found under the root path that is passed in. Here it is:

['01.png', '02.png', '03.png', '04.png', '05.png', '06.png', '07.png', '08.png', '09.png']

self.imgs joins the root path with each file name found under it. Here it is:

['./DRIVE/test/image\\01.png', './DRIVE/test/image\\02.png', './DRIVE/test/image\\03.png', './DRIVE/test/image\\04.png', './DRIVE/test/image\\05.png', './DRIVE/test/image\\06.png', './DRIVE/test/image\\07.png', './DRIVE/test/image\\08.png', './DRIVE/test/image\\09.png']

As shown in the picture:


After the paths and the preprocessing are initialized, the images themselves need to be handled

The training sample and its binary label image must have the same file name; otherwise extra handling is needed. Here, replacing image with label in the training sample's path is enough to locate the corresponding segmentation label
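One thing to keep in mind is that str.replace swaps every occurrence of image in the path, including any in the file name. The sketch below shows the behavior and a stricter alternative; label_path_for is a hypothetical helper, not part of the project, and it assumes the image/label folder layout described above.

```python
import os


def label_path_for(image_path):
    """Swap only the parent folder name from 'image' to 'label' (hypothetical helper)."""
    directory, filename = os.path.split(image_path)
    parent, leaf = os.path.split(directory)
    if leaf != "image":
        raise ValueError("expected the sample to sit in a folder named 'image'")
    return os.path.join(parent, "label", filename)


# The approach used in the dataset file: fine for names such as 01.png
print("./DRIVE/train/image/01.png".replace("image", "label"))

# A file name that itself contains 'image' is changed as well
print("./DRIVE/train/image/image_01.png".replace("image", "label"))

# The stricter helper leaves the file name alone
print(label_path_for(os.path.join(".", "DRIVE", "train", "image", "image_01.png")))
```

For the numbered file names used in this project the simple replace is sufficient; the helper only matters if the naming scheme changes.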

 

Then the images are read, preprocessed, and returned.

In case a label is not a strictly binary image, after scaling (gray value / 255) any intermediate gray value is also mapped to the foreground
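The label handling amounts to the following, shown here on a plain list of gray values rather than a tensor. The function name binarize_label is made up for this sketch; in the dataset file the same effect comes from ToTensor followed by label[label > 0] = 1.

```python
def binarize_label(gray_values):
    """Scale 8-bit gray values to 0..1, then map every non-zero value to 1."""
    scaled = [value / 255.0 for value in gray_values]      # the scaling ToTensor applies
    return [1.0 if value > 0 else 0.0 for value in scaled]


# 3 and 128 are intermediate gray values; both end up as foreground
print(binarize_label([0, 3, 128, 255]))
```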

4. train: training the network

The training code is basically unchanged, so here is a brief walkthrough

Determine the device the network runs on, and move the network to that device

 

Load the training set and the test set

Only the path to the training samples is passed in here, because Data_Loader replaces image with label in each sample path to find the corresponding label image

Because of limited memory, the batch size is set to 1 here

 

Then define the optimizer and the loss function, and set the path where the trained weights will be saved

For BCEWithLogitsLoss, you can refer to this: Talk about the loss function of image segmentation - BCEWithLogitsLoss
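As a rough illustration of what this loss computes for each pixel, the sketch below evaluates binary cross-entropy directly from a raw network output (a logit). It is plain Python, not the PyTorch implementation, and the function names are made up for this example. The first function uses a rearranged form that avoids overflow for large logits; the second applies the sigmoid first and is included only for comparison.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy from a raw logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def bce_after_sigmoid(logit, target):
    """The same loss computed the direct way: sigmoid first, then cross-entropy."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def mean_bce_with_logits(logits, targets):
    """Average the per-pixel loss over all pixels."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


print(bce_with_logits(2.0, 1.0), bce_after_sigmoid(2.0, 1.0))   # the two forms agree
print(mean_bce_with_logits([2.0, -1.0, 0.0], [1.0, 0.0, 1.0]))
```

Because the sigmoid is folded into the loss, the network's last layer outputs raw logits, which is why the later code thresholds the output at 0 rather than at 0.5.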

 

During training the network must be in train mode. Each step then runs the forward pass to get a prediction, computes the loss, backpropagates, and updates the weights by gradient descent

 

Finally, to compute the accuracy, the network must be put in eval mode

Here the network output is converted to a binary image, the predicted binary image is compared with the label pixel by pixel, and the number of matching pixels is divided by the spatial resolution of the image, that is, height times width, to get the accuracy.

The dimension order of test_label is: batch, channel, height, width
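The accuracy computation for one image can be sketched in plain Python as follows. The function name pixel_accuracy is made up for this example, and nested lists stand in for the tensors used in the training script.

```python
def pixel_accuracy(logits, labels):
    """Fraction of pixels whose thresholded prediction matches the label.

    logits and labels are 2-D lists (rows of pixels) of the same shape.
    """
    correct = 0
    total = 0
    for logit_row, label_row in zip(logits, labels):
        for logit, label in zip(logit_row, label_row):
            predicted = 1 if logit >= 0 else 0   # logit >= 0 is the same as sigmoid(logit) >= 0.5
            correct += int(predicted == label)
            total += 1
    return correct / total


logits = [[2.0, -1.0],
          [0.0, -3.0]]
labels = [[1, 0],
          [0, 0]]
print(pixel_accuracy(logits, labels))   # 3 of the 4 pixels match
```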

 

5. predict: segmenting an image

The preprocessing here must match the preprocessing applied to the training samples

 

Load the network and read in the trained parameters

 

Before prediction, a batch dimension has to be added to the image. Before saving the result, the batch and channel dimensions have to be removed again

Then the prediction is converted into a binary image.
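The shape changes can be followed with a small NumPy stand-in. This sketch uses NumPy arrays in place of the torch tensors in the predict script (np.expand_dims and np.squeeze play the role of torch.unsqueeze and torch.squeeze), and logits_to_mask is a name made up for this example.

```python
import numpy as np


def logits_to_mask(logits):
    """Turn network output shaped (1, 1, H, W) into a uint8 (H, W) mask of 0 and 255."""
    squeezed = np.squeeze(logits, axis=(0, 1))            # drop the batch and channel dimensions
    return np.where(squeezed >= 0, 255, 0).astype(np.uint8)


image = np.zeros((1, 584, 565), dtype=np.float32)         # (channel, height, width) after preprocessing
batch = np.expand_dims(image, axis=0)                     # (1, 1, 584, 565): batch dimension added
print(batch.shape)

fake_logits = np.array([[[[1.5, -0.2],
                          [0.0, -3.0]]]])                 # stand-in for the network output
print(logits_to_mask(fake_logits))
```

Naming the axes explicitly in the squeeze keeps the height and width intact even if one of them happens to be 1.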

 

 

6. Results

After training for 20 epochs, the results are as follows

 

The image being predicted comes from the test set. The image in predict is:

Results of UNet segmentation:

 

The real label is:

Most of the vessels are segmented, but some fine details are still missed

The image size is 565*584, and the prediction accuracy is about 0.96

In other words, about 565*584*0.04 ≈ 13198 pixels are still classified incorrectly, and these pixels account for the missing details

7. Complete code

model:

import torch.nn as nn
import torch
import torch.nn.functional as F


# Build the UNet network
class DoubleConv(nn.Module):    # two consecutive convolutions
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(

            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),                           # use BN instead of Dropout
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.double_conv(x)
        return x


class Down(nn.Module):   # downsampling
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        x = self.downsampling(x)
        return x


class Up(nn.Module):    # upsampling
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()

        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)  # transposed convolution
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.upsampling(x1)

        # Pad x1 to the size of x2 so that input images of any size work.
        # Plain ints are used because F.pad expects integer padding amounts.
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)  # concatenate along the channel dimension
        x = self.conv(x)
        return x


class OutConv(nn.Module):   # output layer of the network
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):   # the UNet network
    def __init__(self, in_channels=1, num_classes=1):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        self.in_conv = DoubleConv(in_channels, 64)

        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        self.out_conv = OutConv(64, num_classes)

    def forward(self, x):

        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out_conv(x)

        return x

dataset (data processing):

import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms


data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, ))]),
    "test": transforms.Compose([transforms.ToTensor()])
}


# Data processing file
class Data_Loader(Dataset):     # loads the data
    def __init__(self, root, transforms_train=data_transform['train'], transforms_test=data_transform['test']):    # initialization
        imgs = os.listdir(root)                                                         # file names under root
        self.imgs = [os.path.join(root, img) for img in imgs]                           # full path of every image under root
        self.transforms_train = transforms_train                                        # preprocessing
        self.transforms_test = transforms_test

    def __getitem__(self, index):                      # fetch and preprocess one sample
        image_path = self.imgs[index]                  # pick the image path by index
        label_path = image_path.replace('image', 'label')   # derive label_path from image_path

        image = Image.open(image_path)                      # read the image and its label
        label = Image.open(label_path)

        image = self.transforms_train(image)        # preprocess the sample

        label = self.transforms_test(label)         # preprocess the label
        label[label > 0] = 1                        # map every non-zero value to foreground

        return image, label

    def __len__(self):  # number of samples
        return len(self.imgs)

train (network training):

from model import UNet
from dataset import Data_Loader
from torch import optim
import torch.nn as nn
import torch

# Network training module
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU or CPU
print(device)
net = UNet(in_channels=1, num_classes=1)        # create the network
net.to(device)                                  # move the network to the device

# Load the training set
trainset = Data_Loader("./DRIVE/train/image")
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=1, shuffle=True)
num_train = len(trainset)                   # total number of training samples (31 here); avoids shadowing the built-in len

# Load the test set
testset = Data_Loader("./DRIVE/test/image")
test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=1)

# Optimizer and loss function
optimizer = optim.RMSprop(net.parameters(), lr=0.00001, weight_decay=1e-8, momentum=0.9)     # define the optimizer
criterion = nn.BCEWithLogitsLoss()                             # define the loss function

# Saving the network parameters
save_path = './UNet.pth'       # where the network parameters are saved
best_acc = 0.0                 # best accuracy so far

# Training
for epoch in range(20):

    net.train()     # training mode
    running_loss = 0.0

    for image, label in train_loader:

        optimizer.zero_grad()                          # reset the gradients
        pred = net(image.to(device))                   # forward pass
        loss = criterion(pred, label.to(device))       # compute the loss
        loss.backward()                                # backpropagation
        optimizer.step()                               # gradient descent step

        running_loss += loss.item()                    # accumulate the loss

    net.eval()  # evaluation mode
    acc = 0.0   # accuracy
    total = 0
    with torch.no_grad():
        for test_image, test_label in test_loader:

            outputs = net(test_image.to(device))     # forward pass

            outputs[outputs >= 0] = 1  # convert the prediction to a binary image
            outputs[outputs < 0] = 0

            # Pixel accuracy between prediction and label: acc = matching pixels / total pixels
            acc += (outputs == test_label.to(device)).sum().item() / (test_label.size(2) * test_label.size(3))
            total += test_label.size(0)

    accurate = acc / total  # accuracy over the whole test set
    print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f %%' %
          (epoch + 1, running_loss / num_train, accurate * 100))

    if accurate > best_acc:     # keep the best accuracy
        best_acc = accurate
        torch.save(net.state_dict(), save_path)     # save the network parameters

predict (prediction):

import numpy as np
import torch
import cv2
from model import UNet
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])


# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(in_channels=1, num_classes=1)
net.load_state_dict(torch.load('UNet.pth', map_location=device))
net.to(device)

# Evaluation mode
net.eval()
with torch.no_grad():

    img = Image.open('./predict/img.png')           # read the image to predict
    img = transform(img)                            # preprocessing
    img = torch.unsqueeze(img, dim=0)               # add the batch dimension

    pred = net(img.to(device))                      # network prediction

    pred = torch.squeeze(pred)                      # drop the batch and channel dimensions
    pred = pred.cpu().numpy()                       # move to the CPU and convert to a NumPy array for saving

    pred[pred >= 0] = 255                           # convert to a binary image
    pred[pred < 0] = 0

    pred = np.uint8(pred)                           # convert to an 8-bit image
    cv2.imwrite('./result/res.png', pred)           # save the image


Origin blog.csdn.net/qq_44886601/article/details/128188184