[PyTorch Learning] Step-by-Step Tutorial: Semantic Segmentation of Lung Images Based on UNet

Lung Image Segmentation Based on UNet

Generally speaking, computer vision has three mainstream tasks: classification, detection, and segmentation. Classification places relatively simple demands on the model. It was covered in detail in my earlier PyTorch introductory tutorial, which interested readers can find in my previous blog posts. Detection and segmentation, by contrast, require multi-scale, high-level feature information, so their model architectures are somewhat more complex. In this article I walk through the whole process of a semantic segmentation task on lung CT images using the UNet network. Without further ado, let's get to the point.

1 What is Semantic Segmentation?

Semantic segmentation is a core technology in computer vision. It classifies every pixel in an image, dividing the image into regions that each carry a specific semantic category. In layman's terms, object detection locates and classifies the foreground objects in an image (instances such as cats, dogs, and people), whereas semantic segmentation requires the network to decide the category of every single pixel, producing a precise pixel-level segmentation. It is widely used in autonomous driving.
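To make "classifying each pixel" concrete, here is a small NumPy sketch (an illustration, not code from this tutorial). A segmentation network outputs one score map per class, shaped (num_classes, H, W), and the predicted label map is the per-pixel argmax over the class axis. The 2-class, 3x3 scores and the function name `scores_to_label_map` are made up for illustration.

```python
import numpy as np

def scores_to_label_map(scores):
    """Turn per-class score maps of shape (num_classes, H, W) into an (H, W) label map."""
    # Each pixel receives the index of the class with the highest score at that location
    return np.argmax(scores, axis=0)

# Hypothetical 2-class output for a 3x3 image: channel 0 = background, channel 1 = lung
scores = np.array([
    [[2.0, 0.1, 0.3],
     [1.5, 0.2, 0.1],
     [3.0, 2.5, 0.4]],
    [[0.5, 1.2, 0.9],
     [0.3, 1.8, 2.2],
     [0.1, 0.2, 1.1]],
])
label_map = scores_to_label_map(scores)
print(label_map)
```

Every pixel ends up with exactly one class index, which is the pixel-level output that distinguishes segmentation from detection.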

2 Lung CT image segmentation case

For beginners in semantic segmentation, lung image segmentation is a project that is easy to understand and not too hard to get started with. Below I introduce the case in three parts: the data, the model, and the results and predictions.

2.1 Dataset production

2.1.1 Dataset overview

This task uses 2D images: CT images and their label images, both single-channel with a resolution of 512x512, 267 of each. The data are shown below:

2.1.2 Data preprocessing

In the label images, the background is represented by 0 and the lung by 255. However, PyTorch classification expects class labels to be consecutive integers starting from 0 (as mentioned in an earlier post). We therefore need to change the value 255 that marks the lung to 1. The main code is shown in Code Listing 1.

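# Code Listing 1
# Purpose: read the raw 2D label images and remap the pixel labels: 0 => 0, 255 => 1
import os
import cv2 as cv
import numpy as np

ip_dir = "CT_label/"      # folder of raw label images (hypothetical path)
op_dir = "CT_label_01/"   # folder for the remapped labels (hypothetical path)
os.makedirs(op_dir, exist_ok=True)

n = 0
for each_image in os.listdir(ip_dir):
    image_fullpath = os.path.join(ip_dir, each_image)
    img_array = cv.imread(image_fullpath, 0)   # read as single-channel grayscale
    # 255 (lung) -> 1, everything else (background) -> 0; uint8 so cv.imwrite accepts it
    label_img = np.where(img_array == 255, 1, 0).astype(np.uint8)
    # Nearest-neighbour resize so that no values other than 0 and 1 are introduced
    label_img = cv.resize(label_img, (512, 512), interpolation=cv.INTER_NEAREST)
    # Use a lossless format such as PNG so the 0/1 values survive saving
    output_img = os.path.join(op_dir, each_image)
    cv.imwrite(output_img, label_img)
    n = n + 1
    print("Processed label: %d" % n)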

After the pixel values are mapped, the image size is standardized with a resize function, giving a 512x512 image that contains only the pixel values 0 and 1. On a 0-255 grayscale, a value of 1 is almost indistinguishable from 0 (black), so the processed label images look completely black to the naked eye. The processed labels are shown in the figure.

Some readers may ask: why can't anything be seen in the processed label images? Is something wrong? Not at all. Because we mapped the pixel values to 0 and 1, the label images consist only of those two values, and the human eye cannot distinguish such a tiny difference in brightness. If you want to be sure, pick a few label images at random, read them with OpenCV or PIL, and print their pixel values to check the result.
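The spot check described above can be sketched as follows. This is a minimal example, not the author's code: the helper names are hypothetical, and `spot_check_labels` assumes Pillow is installed and that the folder you pass in holds the remapped label images.

```python
import os
import random

import numpy as np

def label_values(label_array):
    """Return the sorted distinct pixel values found in a label image array."""
    return sorted(np.unique(label_array).tolist())

def is_binary_label(label_array):
    """True if the label image contains only the values 0 and 1."""
    return set(label_values(label_array)) <= {0, 1}

def spot_check_labels(label_dir, k=3, seed=0):
    """Print the distinct pixel values of k randomly chosen label images in label_dir."""
    from PIL import Image  # imported here so the helpers above work without Pillow
    names = sorted(os.listdir(label_dir))
    rng = random.Random(seed)  # fixed seed so the same files are picked on every run
    for name in rng.sample(names, min(k, len(names))):
        arr = np.asarray(Image.open(os.path.join(label_dir, name)))
        print(name, label_values(arr), "OK" if is_binary_label(arr) else "NOT 0/1")
```

For example, `spot_check_labels("CT_label_01")` (a hypothetical folder name) should print `[0, 1]` for each sampled label, or `[0]` for a slice that contains no lung.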

2.1.3 Generate data path

To read the images conveniently, we generate three txt files that record the path of each original image and its corresponding label image (the related image-processing basics were covered in a previous blog post; readers with questions can look there). The code that generates the image and label paths is as follows:

# Code Listing 2
# Purpose: scan the raw 2D image folders and write image/label path pairs to txt files
import os

def walk_dir(dir):
    # Sort the file names so images and labels pair up by name (os.listdir order is arbitrary)
    return [os.path.join(dir, image) for image in sorted(os.listdir(dir))]

original_dir = r'CT_image'
save_dir = r'CT_txt'
os.makedirs(save_dir, exist_ok=True)  # create the output folder if it does not exist

# Assumes the four sub-folders sort in this order:
# test images, test labels, train/val images, train/val labels
img_dir = sorted(os.listdir(original_dir))
img_test = walk_dir(os.path.join(original_dir, img_dir[0]))
img_test_label = walk_dir(os.path.join(original_dir, img_dir[1]))
img_t_v = walk_dir(os.path.join(original_dir, img_dir[2]))
img_t_v_label = walk_dir(os.path.join(original_dir, img_dir[3]))
# The first 188 train/val images are used for training, the rest for validation
img_train = img_t_v[:188]
img_val = img_t_v[188:]
img_train_label = img_t_v_label[:188]
img_val_label = img_t_v_label[188:]

# Check that each image and its label share the same file name
for im, lb in zip(img_t_v + img_test, img_t_v_label + img_test_label):
    assert os.path.basename(im) == os.path.basename(lb), (im, lb)

def write_pairs(txt_name, images, labels):
    # One "image_path<TAB>label_path" pair per line; "w" avoids duplicate lines on re-runs
    with open(os.path.join(save_dir, txt_name), 'w') as f:
        for im, lb in zip(images, labels):
            f.write(im + '\t' + lb + '\n')

write_pairs('train.txt', img_train, img_train_label)  # training set
print("Training set and labels written")
write_pairs('val.txt', img_val, img_val_label)        # validation set
print("Validation set and labels written")
write_pairs('test.txt', img_test, img_test_label)     # test set

Running this produces three text files: train.txt, val.txt, and test.txt. train.txt and val.txt are used to train and validate the model, and test.txt is used to test it. Each line holds an image path and its label path, separated by a tab.

2.1.4 Define Dataset

In PyTorch, the network operates on tensors, so the images we read must be converted into tensors before being fed to the network. Here we use a very important class: torch.utils.data.Dataset.
Dataset is an abstract class for wrapping data: we subclass it to wrap our data and then pass an instance to a DataLoader, which handles batching and shuffling so we can work with the data more efficiently. A subclass of Dataset must override the __len__ and __getitem__ methods: __len__ returns the size of the dataset, and __getitem__ returns a sample by index. The implementation is shown in Code Listing 3.

# Code Listing 3
# Purpose: convert the images that are read in into tensors
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

def read_txt(path):
    # Read the "image_path<TAB>label_path" pairs written by Code Listing 2
    ims, labels = [], []
    with open(path, 'r') as f:
        for line in f.readlines():
            im, label = line.strip().split("\t")
            ims.append(im)
            labels.append(label)
    return ims, labels

class UnetDataset(Dataset):
    def __init__(self, txtpath, transform):
        super().__init__()
        self.ims, self.labels = read_txt(txtpath)
        self.transform = transform  # e.g. torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        im_path = self.ims[index]
        label_path = self.labels[index]
        image = Image.open(im_path)
        image = self.transform(image).float()
        # Class-index map (0 = background, 1 = lung) as int64, the dtype CrossEntropyLoss expects
        label = torch.from_numpy(np.asarray(Image.open(label_path), dtype=np.int64))
        # Tensors stay on the CPU here and are moved to the GPU in the training loop;
        # calling .cuda() inside __getitem__ can fail in DataLoader worker processes
        return image, label

    def __len__(self):
        return len(self.ims)

2.2 Overview of network structure

The UNet architecture looks like a big letter U. The encoder first downsamples with convolution and pooling (Conv+Pooling); the decoder then upsamples with transposed convolutions (Deconv), crops the corresponding low-level encoder feature map, fuses the two, and upsamples again. This process repeats until a 388x388x2 feature map is output (the size given in the original U-Net paper for a 572x572 input), and softmax then produces the final segmentation map. Unlike FCN, which fuses features by element-wise addition, U-Net concatenates features along the channel dimension to form deeper features. The details of the network structure are not repeated here.
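The 388x388 figure comes from the original U-Net paper, which feeds a 572x572 input through unpadded ("valid") 3x3 convolutions. The NumPy sketch below is my own illustration, not code from this tutorial: it traces the spatial size through such a network to show where 388 comes from and how many pixels each encoder map must be cropped, then contrasts U-Net's channel concatenation with FCN-style addition. The function names are hypothetical. This tutorial's model takes 512x512 inputs and may use padded convolutions, in which case the sizes differ and no cropping is needed.

```python
import numpy as np

def unet_valid_sizes(input_size=572, depth=4):
    """Trace spatial sizes through a U-Net built from unpadded ("valid") 3x3 convolutions.

    Returns (output_size, crops): crops[i] is the number of pixels trimmed from each side
    of the encoder feature map before it is concatenated at decoder step i.
    """
    size = input_size
    skips = []
    for _ in range(depth):
        size -= 4               # two 3x3 valid convolutions, each trims 1 pixel per side
        skips.append(size)      # encoder feature map kept for the skip connection
        size //= 2              # 2x2 max pooling halves the size
    size -= 4                   # two more 3x3 valid convolutions at the bottleneck
    crops = []
    for skip in reversed(skips):
        size *= 2                          # 2x2 up-convolution doubles the size
        crops.append((skip - size) // 2)   # center-crop the skip map to match
        size -= 4                          # two 3x3 valid convolutions after concatenation
    return size, crops

def center_crop(feat, target_hw):
    """Center-crop a (C, H, W) feature map to the spatial size target_hw = (h, w)."""
    _, H, W = feat.shape
    h, w = target_hw
    top, left = (H - h) // 2, (W - w) // 2
    return feat[:, top:top + h, left:left + w]

def unet_merge(encoder_feat, decoder_feat):
    """U-Net fusion: crop the encoder map, then concatenate along the channel axis."""
    cropped = center_crop(encoder_feat, decoder_feat.shape[1:])
    return np.concatenate([cropped, decoder_feat], axis=0)

def fcn_merge(encoder_feat, decoder_feat):
    """FCN-style fusion for comparison: element-wise addition, shapes must already match."""
    return encoder_feat + decoder_feat

print(unet_valid_sizes())  # (388, [4, 16, 40, 88])

# Deepest skip connection of the original U-Net: 512 channels at 64x64 from the encoder,
# 512 channels at 56x56 coming up from the bottleneck
enc = np.ones((512, 64, 64), dtype=np.float32)
dec = np.zeros((512, 56, 56), dtype=np.float32)
print(unet_merge(enc, dec).shape)  # (1024, 56, 56): concatenation doubles the channels
print(fcn_merge(dec, dec).shape)   # (512, 56, 56): addition keeps the channel count
```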

2.3 Results and predictions

2.3.1 Transformation of prediction results

The prediction function is included in the shared source code at the end of the article; it is too long to show here, so I will focus on converting the results. During data preparation we mapped the pixel value 255 to 1, so the network also predicts 0 and 1. For the output images, we therefore convert 1 back to the pixel value 255. The process is as follows.

# Code Listing 4
# Purpose: convert a predicted 0/1 label map into a viewable black-and-white image
import cv2
import numpy as np

def translabeltovisual(save_label, path):
    im = cv2.imread(save_label, 0)   # predicted label map with pixel values 0 or 1
    img_array = np.asarray(im)
    # 1 (lung) -> 255 (white), everything else -> 0 (black); uint8 so cv2.imwrite accepts it
    visual_img = np.where(img_array == 1, 255, 0).astype(np.uint8)
    cv2.imwrite(path, visual_img)
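Code Listing 4 converts a 0/1 label map that was already saved to disk. If you instead have the raw network output in memory, you can go straight to a viewable mask. A minimal NumPy sketch, assuming the output for one image has shape (2, H, W) with channel 1 meaning lung; the function name `logits_to_visual` is hypothetical.

```python
import numpy as np

def logits_to_visual(logits):
    """Turn a raw (2, H, W) network output into a 0/255 uint8 mask that can be saved as PNG.

    Assumes channel 0 = background and channel 1 = lung, matching the 0/1 label mapping.
    """
    pred = np.argmax(logits, axis=0)        # (H, W) class indices: 0 or 1
    return (pred * 255).astype(np.uint8)    # 1 -> 255 (white), 0 -> 0 (black)

# Made-up 2x2 output: channel 0 = background scores, channel 1 = lung scores
logits = np.array([[[1.0, 0.0], [0.2, 3.0]],
                   [[0.0, 2.0], [0.5, 1.0]]])
print(logits_to_visual(logits))
```

Softmax is not needed before the argmax, because softmax preserves the ordering of the class scores at each pixel. For a PyTorch output of shape (N, 2, H, W), `output.argmax(dim=1).cpu().numpy()` gives the corresponding (N, H, W) class map.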

2.3.2 Result display

2.3.3 Explanation of the model evaluation functions (partial)

2.3.4 Training curves

TensorBoard is used to record the loss, accuracy, and IoU during training, with the training epoch on the horizontal axis and each metric on the vertical axis. The curves are shown in the figure.

The following conclusions can be drawn from the curve during the training process:
First, on the training set the loss decreases slowly and without oscillation, whereas on the validation set it oscillates noticeably, suggesting that the initial training hyperparameters were not well chosen.
Second, in semantic segmentation, accuracy does not fully reflect an algorithm's performance; the mean Intersection over Union (mIoU) is a better indicator of its real performance. On both the training and validation sets, accuracy and IoU differ by about 20 percentage points, so a segmentation task should track the model's IoU, not just its accuracy.
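The author's evaluation functions are only in the shared source code, so the following is an independent NumPy sketch, not the original implementation, of pixel accuracy and (mean) IoU for the two classes used here (0 = background, 1 = lung). The function names and the 10x10 example are hypothetical.

```python
import numpy as np

def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted class equals the ground-truth class."""
    return float(np.mean(pred == target))

def iou_per_class(pred, target, num_classes=2):
    """Intersection over Union for each class; NaN if a class is absent from both maps."""
    ious = []
    for c in range(num_classes):
        p, t = (pred == c), (target == c)
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        ious.append(inter / union if union > 0 else float("nan"))
    return ious

def mean_iou(pred, target, num_classes=2):
    """Mean IoU over the classes that occur in the prediction or the ground truth."""
    return float(np.nanmean(iou_per_class(pred, target, num_classes)))

# Made-up 10x10 example: a 4x4 lung region, predicted 2 columns too far to the right
target = np.zeros((10, 10), dtype=np.int64)
target[0:4, 0:4] = 1
pred = np.zeros((10, 10), dtype=np.int64)
pred[0:4, 2:6] = 1

print(pixel_accuracy(pred, target))   # 0.84
print(iou_per_class(pred, target))    # background about 0.83, lung about 0.33
print(mean_iou(pred, target))         # about 0.58
```

Here 84% of the pixels are correct, yet the lung IoU is only 1/3, because the large and easy background dominates the accuracy. This is the same effect behind the gap between accuracy and IoU described above.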

3 Source code and data sharing

Follow the WeChat official account "Alchemy Little Genius" and reply "CT" to get the image data and source code!


OVER


Origin blog.csdn.net/weixin_43427721/article/details/125255837