Pytorch Deep Learning Practical Tutorial (3): UNet model training, deep analysis!


This article is included in the GitHub repository https://github.com/Jack-Cherish/PythonPark, which collects technical articles, organized learning materials, and interview experiences from major tech companies. Stars and contributions are welcome.

I. Introduction

This article is part of a PyTorch deep learning tutorial series on semantic segmentation.

The series covers:

  • Basic usage of PyTorch
  • Explanations of semantic segmentation algorithms

PS: All the code in this article can be downloaded from my GitHub; follows and stars are welcome.

II. Project background

A deep learning algorithm is simply a way of solving a problem. Which network to train, what preprocessing to apply, and which loss and optimization method to use are all determined by the specific task.

So, let's first look at today's task.

It is the classic task from the UNet paper: medical image segmentation.

We chose it because it is simple and easy to get started with.

In brief: as shown in the animation, given an image of cell structures, we need to separate each cell from the others.

The training data consists of only 30 images, each with a resolution of 512x512. They are electron micrographs of fruit flies (Drosophila).

With the task introduced, let's start training the model.

III. UNet training

Training a deep learning model can be roughly divided into three steps:

  • Data loading: how to load the data, how to define the labels, and which data augmentation methods to use are all handled in this step.
  • Model selection: the model is already prepared; it is the UNet network covered in the previous article in this series.
  • Algorithm selection: choosing the loss function and the optimization algorithm.

These steps are general; below we go through each one for today's medical image segmentation task.

1. Data loading

A lot can be done in this step, but in essence it comes down to how to load the images and how to define the labels. To make the algorithm more robust, or to enlarge the dataset, we can also apply some data augmentation.

Since we are processing data, let's first look at what the data looks like before deciding how to handle it.

The data is ready and available on GitHub.

If downloading from GitHub is slow, you can use the Baidu Netdisk link at the end of the original article to download the dataset.

The data is split into a training set and a test set of 30 images each. The training set has labels; the test set does not.

How data loading is handled depends on the task and the dataset. Our segmentation task needs little processing, but because the dataset is tiny (only 30 images), we can use data augmentation to expand it.

PyTorch provides a framework that makes data loading convenient: subclass torch.utils.data.Dataset and pass an instance to torch.utils.data.DataLoader. Here is the template in pseudocode:

# ================================================================== #
#                Input pipeline for custom dataset                 #
# ================================================================== #

import torch

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one sample from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0 

# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

This is a standard template. We use it to load the data, define the labels, and perform data augmentation.
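To see the template in action without downloading anything, here is a minimal sketch that fills it in with randomly generated in-memory tensors instead of files. RandomSegDataset and its toy mask rule are hypothetical and only illustrate the three methods; the real dataset.py reads PNG files instead.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomSegDataset(Dataset):
    """Hypothetical in-memory dataset that fills in the template above.

    Each sample is a (1, H, W) float image and a (1, H, W) binary mask,
    generated randomly so the example runs without any files on disk.
    """

    def __init__(self, num_samples=6, size=32, seed=0):
        # 1. "Initialize file paths": here we simply pre-generate tensors
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(num_samples, 1, size, size, generator=g)
        self.labels = (self.images > 0.5).float()  # toy mask derived from the image

    def __getitem__(self, index):
        # 2./3. Read one sample and return an (image, label) pair
        return self.images[index], self.labels[index]

    def __len__(self):
        # Total number of samples; DataLoader uses this to build batches
        return self.images.shape[0]


dataset = RandomSegDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_shapes = [tuple(image.shape) for image, label in loader]
print(batch_shapes)  # 6 samples with batch_size=4 -> one batch of 4, then one of 2
```

DataLoader only needs __len__ and __getitem__; batching and shuffling come for free.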

Create a dataset.py file and write the code as follows:

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random

class ISBI_Loader(Dataset):
    def __init__(self, data_path):
        # Initialization: collect the paths of all images under data_path
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))

    def augment(self, image, flipCode):
        # Augment with cv2.flip: flipCode 1 = horizontal flip, 0 = vertical flip, -1 = both
        flip = cv2.flip(image, flipCode)
        return flip
        
    def __getitem__(self, index):
        # Get the image path for this index
        image_path = self.imgs_path[index]
        # Derive label_path from image_path
        label_path = image_path.replace('image', 'label')
        # Read the training image and the label image
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # Convert both to single-channel (grayscale) images
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        # Process the label: change pixel values of 255 to 1
        if label.max() > 1:
            label = label / 255
        # Randomly apply augmentation; 2 means no flip.
        # Flip while the arrays are still 2-D (H, W): cv2.flip works on (H, W) or
        # (H, W, C) images and would not flip the intended axes of a (1, H, W) array.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)
        # Add the channel dimension: (H, W) -> (1, H, W)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        return image, label

    def __len__(self):
        # Return the size of the training set
        return len(self.imgs_path)

    
if __name__ == "__main__":
    isbi_dataset = ISBI_Loader("data/train/")
    print("Number of samples:", len(isbi_dataset))
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=2, 
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)

Running the code prints the number of samples followed by the shape of each batch; with 512x512 images and batch_size=2, each batch has shape torch.Size([2, 1, 512, 512]).

A brief explanation of the code:

The __init__ function initializes the class. It uses glob to collect the paths of all images under the given data path and stores them in the self.imgs_path list; the images themselves are read later, in __getitem__.

The __len__ function returns the number of samples; once the class is instantiated, it is called through len().

The __getitem__ function fetches one sample. This is where you define how to read and process the data, including any preprocessing and data augmentation. My processing is very simple: read the image and its label and convert both to single-channel images. Because the label pixels are 0 and 255, we divide by 255 to turn them into 0 and 1. Finally, a random flip is applied as data augmentation.
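The label handling described above can be checked on a tiny hypothetical array. This NumPy-only sketch assumes the label has already been read and converted to grayscale (uint8 values 0 and 255) and applies the same divide-by-255 rule and channel reshape as __getitem__.

```python
import numpy as np

# Hypothetical 2x3 grayscale label, as cv2.imread + cvtColor would return it (uint8, 0 or 255)
label = np.array([[0, 255, 255],
                  [255, 0, 0]], dtype=np.uint8)

# Same rule as in __getitem__: rescale a 0/255 mask to 0/1
if label.max() > 1:
    label = label / 255          # true division -> float64 array of 0.0 and 1.0

# Add the channel axis expected by PyTorch's (C, H, W) layout
label = label.reshape(1, label.shape[0], label.shape[1])

print(label.shape, label.dtype)  # (1, 2, 3) float64
print(np.unique(label))          # [0. 1.]
```

Note that true division turns the label into float64, which is one reason train.py casts both tensors to torch.float32 before the forward pass.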

The augment function is the data augmentation function. You can put whatever augmentation you like here; I simply used cv2.flip to flip the image horizontally, vertically, or both.
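The key property of this augmentation is that the image and its label receive the same flip, so the mask stays aligned with the pixels. The sketch below illustrates this with np.flip as a stand-in for cv2.flip (so it runs without OpenCV); FLIP_AXES and random_flip_pair are hypothetical names, and the axis mapping assumes 2-D (H, W) arrays, i.e. flipping before the channel dimension is added.

```python
import random

import numpy as np

# cv2.flip codes mapped to NumPy axes for a 2-D (H, W) array:
#   0 -> flip rows (vertical), 1 -> flip columns (horizontal), -1 -> both
FLIP_AXES = {0: (0,), 1: (1,), -1: (0, 1)}


def random_flip_pair(image, label, rng=random):
    """Apply the same random flip to an image and its mask; code 2 means no flip."""
    flip_code = rng.choice([-1, 0, 1, 2])
    if flip_code != 2:
        axes = FLIP_AXES[flip_code]
        image = np.flip(image, axis=axes)
        label = np.flip(label, axis=axes)
    # np.flip returns a view with negative strides; copy so torch.from_numpy accepts it
    return np.ascontiguousarray(image), np.ascontiguousarray(label), flip_code


image = np.arange(12).reshape(3, 4)
label = (image % 2).astype(np.float64)  # toy mask defined pixel-wise from the image
flipped_image, flipped_label, code = random_flip_pair(image, label, random.Random(1))
# The same flip was used for both arrays, so the mask still matches the image pixel by pixel
print(code, np.array_equal(flipped_label, flipped_image % 2))
```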

In this class you don't need to shuffle the dataset or worry about splitting it into batches of batch_size. After instantiating the class, we pass it to torch.utils.data.DataLoader, which takes the batch size and whether to shuffle the data as arguments.

The DataLoader provided by PyTorch is very powerful. For more advanced use, you can specify how many worker processes load the data (num_workers) and whether to place batches in pinned (page-locked) memory to speed up copies to the GPU (pin_memory). This article does not need these options, so we won't go into them further.
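For reference, here is a sketch of how those two options are passed to DataLoader. make_loader is a hypothetical helper, the tensors are placeholders, and enabling pin_memory only when CUDA is present is a choice made here, not a requirement.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=2, num_workers=0):
    """Hypothetical helper showing the DataLoader options mentioned above.

    num_workers > 0 loads batches in that many subprocesses; pin_memory=True
    puts batches in page-locked host memory, which speeds up copies to a CUDA
    device, so it is only enabled here when a GPU is present.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


images = torch.zeros(5, 1, 8, 8)
labels = torch.ones(5, 1, 8, 8)
loader = make_loader(TensorDataset(images, labels), batch_size=2)
batch_sizes = [image.shape[0] for image, _ in loader]
print(batch_sizes)  # [2, 2, 1]
```

When you set num_workers > 0 on Windows or macOS, keep the loader code under an `if __name__ == "__main__":` guard, since worker processes there are started with spawn and re-import the script.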

2. Model selection

The model is already chosen: we use the UNet network structure explained in the previous article, "Pytorch Deep Learning Practical Tutorial (2): UNet Semantic Segmentation Network".

However, we need to modify the network slightly. With the structure from the paper, the output is slightly smaller than the input image, because its 3x3 convolutions use no padding. If you use the paper's structure as-is, you have to resize the output afterwards. To avoid this step, we modify the network so that its output size exactly equals the input image size: every 3x3 convolution uses padding=1, and the upsampling path pads its feature maps where needed to match the skip connections.
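The size difference can be worked out with the standard convolution output formula. This pure-Python sketch traces one spatial dimension through the four down/up levels; the helper names are hypothetical, and it assumes 2x2 pooling, 2x upsampling, and two 3x3 convolutions per level, as in both the paper and unet_parts.py.

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    """Spatial size after a convolution (standard formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def unet_output_size(size, padding):
    """Trace one spatial dimension along the decoder path of a 4-level U-Net.

    Each level applies two 3x3 convolutions; the encoder then halves the size
    with 2x2 max pooling and the decoder doubles it with 2x upsampling. The paper
    crops the skip features to the decoder size; with padding=1 and sizes
    divisible by 16 the two already agree.
    """
    def double_conv(s):
        return conv_out(conv_out(s, padding=padding), padding=padding)

    s = double_conv(size)          # first level (inc)
    for _ in range(4):             # down1..down4: pool, then double conv
        s = double_conv(s // 2)
    for _ in range(4):             # up1..up4: upsample, then double conv
        s = double_conv(s * 2)
    return s


print(unet_output_size(572, padding=0))  # 388: unpadded convs, as in the paper
print(unet_output_size(512, padding=1))  # 512: padding=1 makes every conv size-preserving
```

With the paper's 572x572 input this reproduces its 388x388 output map, while padding=1 gives back 512 for our 512x512 images.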

Create the unet_parts.py file and write the following code:

""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
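The F.pad call in Up.forward is what keeps the concatenation working when the input height or width is not divisible by 16: max pooling floors odd sizes, so after upsampling the decoder map can be one pixel smaller than its skip connection. The sketch below replays that arithmetic on hypothetical tensor sizes (a 5x7 skip map against the 2x3 map that pooling it would produce).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Decoder feature map (2x3) and a hypothetical odd-sized skip connection (5x7)
x1 = torch.randn(1, 4, 2, 3)
x2 = torch.randn(1, 4, 5, 7)

x1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)(x1)  # -> (1, 4, 4, 6)

# Same arithmetic as Up.forward: split the size difference between both sides
diffY = x2.size(2) - x1.size(2)  # 1
diffX = x2.size(3) - x1.size(3)  # 1
# F.pad takes (left, right, top, bottom) for the last two dimensions
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

x = torch.cat([x2, x1], dim=1)
print(tuple(x1.shape), tuple(x.shape))  # (1, 4, 5, 7) (1, 8, 5, 7)
```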

Create unet_model.py file and write the following code:

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""

import torch.nn as nn

from .unet_parts import DoubleConv, Down, Up, OutConv

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

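if __name__ == '__main__':
    # Because of the relative import above, run this from the project root as:
    #   python -m model.unet_model
    net = UNet(n_channels=3, n_classes=1)
    print(net)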

After this adjustment, the output size of the network will be the same as the input size of the picture.

3. Algorithm selection

The choice of loss is very important: it determines how well the algorithm fits the data.

The loss is also chosen according to the task. Today's task only requires segmenting the cell boundaries, which is a simple binary classification for each pixel, so we can use BCEWithLogitsLoss.

What is BCEWithLogitsLoss? It is a PyTorch loss that applies a sigmoid to the network's raw outputs and then computes the binary cross-entropy, combined in a single, numerically more stable step.

Its formula (with the default reduction='mean', averaged over all N pixels) is:

$$\ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N}\left[y_n \log \sigma(x_n) + (1 - y_n)\log\bigl(1 - \sigma(x_n)\bigr)\right], \qquad \sigma(x) = \frac{1}{1 + e^{-x}}$$

where x_n is the network's raw output (logit) for pixel n and y_n ∈ {0, 1} is its label.

Friends who have followed my machine learning tutorial series will recognize this formula: it is the loss function of logistic regression. It relies on the sigmoid function mapping any real number into the range (0, 1), which lets its output be used as a probability for classification.

For the detailed derivation, see my machine learning tutorial "Machine Learning Practical Tutorial (6): The Gradient Ascent Algorithm in the Basics of Logistic Regression"; I won't repeat it here.
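A quick way to convince yourself that BCEWithLogitsLoss implements the formula above is to compare it with a hand-written version on random data. This sketch uses made-up logits and targets; in practice prefer the built-in loss, which PyTorch documents as more numerically stable than a separate sigmoid followed by BCELoss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)                    # raw network outputs (no sigmoid)
target = torch.randint(0, 2, (2, 1, 4, 4)).float()  # binary mask with values 0/1

builtin = nn.BCEWithLogitsLoss()(logits, target)

# The formula written out: sigmoid turns logits into probabilities,
# then binary cross-entropy is averaged over all pixels (reduction='mean')
p = torch.sigmoid(logits)
manual = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()

print(builtin.item(), manual.item())
```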

With the objective function (the loss) defined, how do we optimize it?

The simplest way is the familiar gradient descent algorithm, which gradually approaches a local minimum.

However, this simple optimizer converges slowly, so it takes many iterations to find a good solution.

Most optimization algorithms are, at heart, variants of gradient descent. The most common one, SGD (stochastic gradient descent), estimates the gradient from a small random batch of data instead of the whole dataset. Momentum SGD adds a momentum term that accumulates past gradients with exponential decay.
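To make the difference between plain gradient descent and momentum concrete, here is a toy sketch on the one-dimensional objective f(w) = 0.01 * w**2 (an assumption chosen because its gentle slope makes plain gradient descent slow). The momentum form v = beta * v + g, w = w - lr * v matches how PyTorch's SGD applies momentum with zero dampening.

```python
def grad(w, curvature=0.02):
    # Gradient of the toy objective f(w) = 0.01 * w**2 (a flat bowl with its minimum at w = 0)
    return curvature * w


def gradient_descent(w, lr=0.1, steps=100):
    for _ in range(steps):
        w -= lr * grad(w)          # step directly against the current gradient
    return w


def momentum_descent(w, lr=0.1, beta=0.9, steps=100):
    v = 0.0
    for _ in range(steps):
        v = beta * v + grad(w)     # velocity: exponentially decayed sum of past gradients
        w -= lr * v                # step along the accumulated velocity
    return w


gd = gradient_descent(5.0)
mom = momentum_descent(5.0)
print(gd, mom)
```

On this flat objective, plain gradient descent is still above 4 after 100 steps, while the momentum version ends well below 1, because the velocity keeps adding up gradients that point in a consistent direction.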

Beyond these basic optimizers, there are adaptive methods. Their key feature is that each parameter has its own learning rate, which is adjusted automatically throughout training to achieve better convergence.

This article uses the adaptive optimizer RMSProp.

Space is limited, so I won't go into detail here; RMSProp alone could fill an article. To understand RMSProp, you first need to know AdaGrad, because RMSProp is an improvement on AdaGrad: instead of summing all past squared gradients, it keeps an exponentially decaying average of them, so the learning rate does not keep shrinking toward zero.
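The difference between AdaGrad and RMSProp comes down to how the squared gradients are accumulated. This NumPy sketch tracks the effective learning rate lr / sqrt(accumulator) for a single parameter under a constant gradient; the hyperparameter values are illustrative assumptions (alpha=0.99 and eps=1e-8 happen to be PyTorch's RMSprop defaults), and effective_lrs is a hypothetical helper.

```python
import numpy as np


def effective_lrs(grads, lr=0.01, alpha=0.99, eps=1e-8):
    """Per-step effective learning rate lr / sqrt(accumulator) for one parameter.

    AdaGrad accumulates the *sum* of squared gradients, so its step size can
    only shrink; RMSProp keeps an exponential moving average instead, so the
    step size levels off when gradients stay the same size.
    """
    ada_acc, rms_acc = 0.0, 0.0
    ada, rms = [], []
    for g in grads:
        ada_acc += g * g                                  # AdaGrad: running sum
        rms_acc = alpha * rms_acc + (1 - alpha) * g * g   # RMSProp: moving average
        ada.append(lr / (np.sqrt(ada_acc) + eps))
        rms.append(lr / (np.sqrt(rms_acc) + eps))
    return np.array(ada), np.array(rms)


ada, rms = effective_lrs(np.ones(1000))  # constant gradient of 1.0
print(ada[-1], rms[-1])
```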

There are also more advanced optimizers than RMSProp, such as the well-known Adam, which can be viewed as Momentum combined with RMSProp, plus a bias correction.

In short, as a beginner you only need to know that RMSProp is an adaptive optimizer that tunes each parameter's learning rate automatically.
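For readers who want to see what optim.RMSprop in train.py actually does, here is a hand-written sketch of one update of the default (non-centered) variant with weight decay and momentum, checked against torch.optim.RMSprop on a toy loss. rmsprop_step is a hypothetical helper that mirrors the update rule described in the PyTorch documentation; it is not the library's own code.

```python
import torch


def rmsprop_step(param, grad, state, lr, alpha=0.99, eps=1e-8, weight_decay=0.0, momentum=0.0):
    """One non-centered RMSprop update written out by hand."""
    grad = grad + weight_decay * param                       # L2 penalty folded into the gradient
    state["square_avg"] = alpha * state["square_avg"] + (1 - alpha) * grad * grad
    avg = state["square_avg"].sqrt() + eps                   # per-parameter scale
    if momentum > 0:
        state["buf"] = momentum * state["buf"] + grad / avg  # momentum on the scaled step
        return param - lr * state["buf"]
    return param - lr * grad / avg


# Compare against torch.optim.RMSprop with the hyperparameters used in train.py
torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)
opt = torch.optim.RMSprop([w], lr=1e-5, weight_decay=1e-8, momentum=0.9)

manual = w.detach().clone()
state = {"square_avg": torch.zeros_like(manual), "buf": torch.zeros_like(manual)}

for _ in range(3):
    opt.zero_grad()
    loss = (w ** 2).sum()            # toy loss with gradient 2 * w
    loss.backward()
    opt.step()
    manual = rmsprop_step(manual, 2 * manual, state, lr=1e-5, weight_decay=1e-8, momentum=0.9)

matches = torch.allclose(w.detach(), manual, atol=1e-6)
print(matches)
```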

Now we can write the UNet training code. Create train.py with the following code:

from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    # Load the training set
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size, 
                                               shuffle=True)
    # Define the RMSprop optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Define the loss function
    criterion = nn.BCEWithLogitsLoss()
    # Track the best (lowest) loss, initialized to +infinity
    best_loss = float('inf')
    # Train for `epochs` epochs
    for epoch in range(epochs):
        # Training mode
        net.train()
        # Iterate over the data in batches of batch_size
        for image, label in train_loader:
            optimizer.zero_grad()
            # Copy the data to the device
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # Forward pass: compute the prediction
            pred = net(image)
            # Compute the loss
            loss = criterion(pred, label)
            print('Loss/train', loss.item())
            # Save the parameters with the lowest loss so far; compare plain floats
            # (.item()) so best_loss does not hold a reference to the autograd graph
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(net.state_dict(), 'best_model.pth')
            # Update the parameters
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    # Select the device: CUDA if available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: 1 input channel (grayscale), 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Specify the training set path and start training
    data_path = "data/train/"
    train_net(net, device, data_path)

To keep the project clear and organized, create a model folder for the model-related code, i.e. the network structure files unet_parts.py and unet_model.py.

Create a utils folder for utility code, such as the data loading tool dataset.py.

This modular layout makes the code much easier to maintain.

Place train.py in the project root directory. A brief explanation of the code follows.

Since there are only 30 images, we do not split off a validation set. Instead, we save the network parameters that achieve the lowest training loss as the best model.

If everything works, you will see the loss gradually decrease and converge.

IV. Prediction

After the model is trained, we can use it to see the effect on the test set.

Create a predict.py file in the project root directory and write the following code:

import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

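if __name__ == "__main__":
    # Select the device: CUDA if available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel input, 1 output class
    net = UNet(n_channels=1, n_classes=1)
    # Copy the network to the device
    net.to(device=device)
    # Load the trained parameters
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Evaluation mode (BatchNorm uses its running statistics)
    net.eval()
    # Collect all test image paths, skipping result files from earlier runs
    tests_path = [p for p in glob.glob('data/test/*.png') if not p.endswith('_res.png')]
    # Iterate over all images
    for test_path in tests_path:
        # Path for saving the result: <name>.png -> <name>_res.png
        save_res_path = os.path.splitext(test_path)[0] + '_res.png'
        # Read the image (OpenCV loads it in BGR order)
        img = cv2.imread(test_path)
        # Convert to grayscale, the same way as in dataset.py
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Reshape to an array with batch size 1, 1 channel, and size 512x512
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        # Convert to a tensor
        img_tensor = torch.from_numpy(img)
        # Copy the tensor to the device (CPU or CUDA)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        # Predict without tracking gradients
        with torch.no_grad():
            # The network outputs logits (it was trained with BCEWithLogitsLoss),
            # so apply sigmoid to get probabilities before thresholding at 0.5
            pred = torch.sigmoid(net(img_tensor))
        # Extract the single (H, W) result map
        pred = pred.cpu().numpy()[0, 0]
        # Binarize: foreground 255, background 0
        pred = np.where(pred >= 0.5, 255, 0).astype(np.uint8)
        # Save the image
        cv2.imwrite(save_res_path, pred)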

After running the script, the prediction results are saved in the data/test directory.
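One detail worth checking when post-processing the output: because the network is trained with BCEWithLogitsLoss, net(img_tensor) returns logits, not probabilities. A 0.5 cut-off is only meaningful after torch.sigmoid (equivalently, thresholding the logits at 0); thresholding raw logits at 0.5 corresponds to a probability of about 0.62. The sketch below compares the options on four hypothetical logit values.

```python
import torch

# Hypothetical logits for four pixels, standing in for the output of net(img_tensor)
logits = torch.tensor([-2.0, 0.2, 0.4, 3.0])

probs = torch.sigmoid(logits)                        # logits -> probabilities in (0, 1)
mask_from_probs = (probs >= 0.5).float() * 255       # threshold probabilities at 0.5
mask_from_logits = (logits >= 0.0).float() * 255     # equivalent in exact arithmetic: sigmoid(x) >= 0.5 iff x >= 0
mask_naive = (logits >= 0.5).float() * 255           # raw logits at 0.5, i.e. probability >= ~0.62

print(mask_from_probs.tolist())  # [0.0, 255.0, 255.0, 255.0]
print(mask_naive.tolist())       # [0.0, 0.0, 0.0, 255.0]
```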


That's it!

V. Conclusion

  • This article covered the three steps of training a model: data loading, model selection, and algorithm selection.
  • This is a simple example; training real-world vision tasks is much more involved. For example, you would choose which model to save based on its accuracy on a validation set, and log to TensorBoard to monitor loss convergence.
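The first point can be sketched as follows: hold out part of the data with torch.utils.data.random_split and keep the weights with the best validation score. This is a minimal sketch under stated assumptions: it uses validation loss rather than accuracy as the selection criterion, and the random tensors, the 1x1 convolution standing in for UNet, and the evaluate helper are all hypothetical placeholders.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split


def evaluate(net, loader, criterion, device):
    """Mean per-sample loss of `net` over `loader`, without updating weights."""
    net.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for image, label in loader:
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # criterion returns the batch mean, so weight it by the batch size
            total += criterion(net(image), label).item() * image.size(0)
            count += image.size(0)
    return total / count


# Hypothetical stand-ins: 30 random 16x16 images/masks and a 1x1-conv "network"
torch.manual_seed(0)
images = torch.rand(30, 1, 16, 16)
labels = (images > 0.5).float()
train_set, val_set = random_split(TensorDataset(images, labels), [24, 6],
                                  generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)

device = torch.device('cpu')
net = nn.Conv2d(1, 1, kernel_size=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-2)

best_val, best_state = float('inf'), None
for epoch in range(3):
    net.train()
    for image, label in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(image.to(device)), label.to(device))
        loss.backward()
        optimizer.step()
    val_loss = evaluate(net, val_loader, criterion, device)
    # Keep the weights that do best on held-out data, not on individual training batches
    if val_loss < best_val:
        best_val = val_loss
        best_state = copy.deepcopy(net.state_dict())  # torch.save(best_state, ...) in a real run
    print(f'epoch {epoch}: val loss {val_loss:.4f}')
```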

If you found this article helpful, please give it a like. Search for the WeChat official account 【JackCui-AI】 to follow me.

Source: blog.csdn.net/c406495762/article/details/106349644