[3D Image Segmentation] VNet 3D Image Segmentation 4 based on Pytorch (Rewriting the Data Flow)

At the end of the previous article, [3D Image Segmentation] VNet 3D Image Segmentation 2 based on Pytorch (Basic Data Flow), we mentioned the following problems encountered during the training phase:

In the 3D segmentation training task using the VNet model, the input size is 16*96*96, and the cropping is done inside the Dataset class, where the image and mask patches are cut out. However, several problems were discovered during training:

  1. Data loading took a long time: about 30 minutes passed between the start of training and the first printed batch of the training loop;
  2. With batch=64 and num_workers=8 in torch.utils.data.DataLoader, there is a long stall whenever the iteration count reaches a multiple of 8;
  3. When training in parallel on 4 GPUs, GPU utilization stays at 0 for long stretches, occasionally rises, and then drops back to 0 in an instant;
  4. Checking memory usage with free -m shows that buff/cache grows steadily and gradually approaches full capacity.

When this happens, what is the problem? The model trains normally and converges well, but it is simply too slow. Analyzing the data-reading code in myDataset, there are several places that are likely time-consuming and memory-hungry:

  1. The getAnnotations function has to read the file names and the corresponding nodule coordinates from the csv file and store them as a dictionary, which permanently occupies memory;
  2. The getNpyFile_Path function is called for both dataFile_paths and labelFile_paths; part of the work is repeated, so this portion of the memory footprint is doubled;
  3. The get_annos_label function has the same problem: part of the work is repeated, doubling this portion of the footprint as well.

All of the above functions run in the class's __init__ stage. These repeated loops are probably the main cause of the long wait before the batch loop starts; on top of that, the duplicated memory usage further degrades performance and slows down subsequent training.
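To see where the time actually goes, it helps to time the two stages separately. A minimal sketch (myDataset stands for whichever Dataset class you are testing; the path is a placeholder):

import time
from torch.utils.data import DataLoader

t0 = time.time()
dataset = myDataset('./train')  # placeholder path; the heavy __init__ work happens here
print('__init__ took %.1f s' % (time.time() - t0))

loader = DataLoader(dataset, batch_size=64, num_workers=8)
t0 = time.time()
image, label = next(iter(loader))  # the first batch also pays the worker start-up cost
print('first batch took %.1f s' % (time.time() - t0))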

To solve the above problems, the 2.0 version of the data loading in this article was produced. The biggest change is that the nodule coordinates are no longer read from the csv file but from npy files. In this way image, mask, and bbox are all single files in one-to-one correspondence. Subsequent training confirmed it: the time-consuming problem was solved and training became faster.

So, as long as we simplify the anchor value and reduce the memory usage in the __init__ stage, this problem should be solved cleanly. This article therefore follows one principle: discard as much as possible in the data preprocessing stage, keep only the simplest one-to-one structure, and move preprocessing up front so it is not invoked while the data flow is running.

For LUNA16 data preprocessing, you can refer to the earlier articles in this series. The data used in this article is generated there, as described below.

1. Set up a data flow framework

In PyTorch, the data flow for training follows the structure below. The main ideas are:

  1. __init__ is executed when the class is instantiated. Here you need to settle on an anchor value from which all the content needed for training can be derived, while occupying as little memory and time as possible;
  2. __getitem__ fetches one image and its label information according to the anchor value determined in __init__, performs reading, augmentation, and similar operations, and finally returns Tensor values;
  3. __len__ returns the number of anchor values, i.e. the length of one training epoch.

The following is a simple framework skeleton, kept here for reference; it can be filled in when building data flows later.

import torch
from torch.utils.data import Dataset

class myDataset_v3(Dataset):
    def __init__(self, data_dir, isTrain=True):
        self.data = []

        if isTrain:
            self.data = ...  # collect training samples here
        else:
            self.data = ...  # collect validation samples here

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # ********** get file dir **********
        image, label = self.data[index]  # get whole data for one subject

        # ********** change data type from numpy to torch.Tensor **********
        image = torch.from_numpy(image).float()  
        label = torch.from_numpy(label).float()  
        return image, label
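A minimal usage sketch for this skeleton (assuming self.data has been filled with (image, label) numpy pairs; the directory is a placeholder):

from torch.utils.data import DataLoader

dataset_train = myDataset_v3('./train', isTrain=True)
train_loader = DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=0)

for batch_index, (image, label) in enumerate(train_loader):
    # image and label arrive here as torch.Tensor, ready for model(image)
    print(batch_index, image.shape, label.shape)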

The parameters of this class were introduced in detail in another article; if you are interested, you can go and read it: [BraTS] Brain Tumor Segmentation 3 (Building Data Flow)

2. Fill in the framework

Through the previous four blogs in this series (6, 7, 8, 9), I believe you have already processed the original LUNA16 dataset into one-to-one correspondence. The data formats required for training include:

  1. _bboxes.npy: records the center-point coordinates and radius of each nodule;
  2. _clean.nrrd: the original CT image array;
  3. _mask.nrrd: the label file (mask array), with the same shape as _clean.nrrd.

There are also some other .npy files, which record various quantities from the conversion stage and are not used during training, so they are not expanded on here. The three files above are the ones that matter, and they correspond one to one by seriesUID.
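Before building the Dataset, it is worth confirming that the correspondence really holds. A quick check along these lines (data_dir is a placeholder):

import os

data_dir = './train'  # placeholder
uids = [f.replace('_bboxes.npy', '') for f in os.listdir(data_dir) if f.endswith('_bboxes.npy')]
for uid in uids:
    for suffix in ('_clean.nrrd', '_mask.nrrd'):
        assert os.path.exists(os.path.join(data_dir, uid + suffix)), uid + suffix + ' is missing'
print('checked %d seriesUIDs, all complete' % len(uids))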

With this in place, when constructing myDataset_v3(Dataset), think about it this way: in the __init__ stage, what can serve as the anchor point, occupying as little memory as possible, such that the required images and annotation information can then be fetched one by one in the __getitem__ stage?

The answer is the seriesUID file name. One name pulls in all three files, and a plain list is enough to hold the names, which is the most memory-efficient option. So the __init__ stage is defined as follows:

class myDataset_v3(Dataset):
    def __init__(self, data_dir, crop_size=(16, 96, 96), isTrain=False):
        self.bboxesFile_path = []
        for file in os.listdir(data_dir):
            if '_bboxes.npy' in file:
                self.bboxesFile_path.append(os.path.join(data_dir, file))

        self.crop_size = crop_size
        self.crop_size_z, self.crop_size_h, self.crop_size_w = crop_size
        self.isTrain = isTrain

The definition of __len__ then follows naturally:

    def __len__(self):
        return len(self.bboxesFile_path)

The most important and most difficult part is the definition of __getitem__. Here a few things need to happen:

  1. Get the path of each file;
  2. Load the data corresponding to the file;
  3. Crop out the target patch;
  4. Assemble everything into Tensors.

But while defining __getitem__, a problem surfaced, as follows:

    def __getitem__(self, index):
        bbox_path = self.bboxesFile_path[index]
        img_path = bbox_path.replace('_bboxes.npy', '_clean.nrrd')
        label_path = bbox_path.replace('_bboxes.npy', '_mask.nrrd')

        img, img_shape = self.load_img(img_path)
        label = self.load_mask(label_path)
        zyx_centerCoor = self.getBboxes(bbox_path)

    def getBboxes(self, bboxFile_path):
        bboxes_array = np.load(bboxFile_path, allow_pickle=True)
        bboxes_list = bboxes_array.tolist()

        xyz_list = [[zyx[0], zyx[2], zyx[1]] for zyx in bboxes_list]  # swap the last two coordinates to match the (0, 2, 1) image transpose

        return random.choice(xyz_list)

This is mainly because one _bboxes.npy file does not record just one nodule. If you fetch the bbox inside __getitem__ like this, each index can only cut out one patch, and there is no way to cover every nodule when a file contains several. So random.choice is used here to randomly pick one nodule.

However, this approach is not good, because it reduces how often some nodules appear during learning. Although the choice is random, it effectively shrinks certain parts of the data: over the same number of training epochs, the nodules in single-nodule files get learned relatively more often.
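The bias is easy to demonstrate with a toy simulation (the uids and nodule names below are made up):

import random
from collections import Counter

files = {'uid_A': ['nodule_A1'],
         'uid_B': ['nodule_B1', 'nodule_B2', 'nodule_B3']}
counts = Counter()
for _ in range(6000):  # 6000 epochs, one random.choice draw per file per epoch
    for uid, nodules in files.items():
        counts[random.choice(nodules)] += 1
print(counts)  # nodule_A1 ~ 6000, while each nodule in uid_B only ~ 2000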

To solve this problem, each nodule is paired directly with its file name, so that every nodule gets an equal chance. The code looks like this:

import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
import nrrd
import cv2

class myDataset_v3(Dataset):
    def __init__(self, data_dir, crop_size=(16, 96, 96), isTrain=False):
        self.dataFile_path_bboxes = []
        for file in os.listdir(data_dir):
            if '_bboxes.npy' in file:
                one_path_bbox_list = self.getBboxes(os.path.join(data_dir, file))
                self.dataFile_path_bboxes.extend(one_path_bbox_list)

        self.crop_size = crop_size
        self.crop_size_z, self.crop_size_h, self.crop_size_w = crop_size
        self.isTrain = isTrain

    def __getitem__(self, index):
        bbox_path, zyx_centerCoor = self.dataFile_path_bboxes[index]

        img_path = bbox_path.replace('_bboxes.npy', '_clean.nrrd')
        label_path = bbox_path.replace('_bboxes.npy', '_mask.nrrd')

        img, img_shape = self.load_img(img_path)
        # print('img_shape:', img_shape)
        label = self.load_mask(label_path)

        # print('zyx_centerCoor:', zyx_centerCoor)

        cutMin_list = self.getCenterScope(img_shape, zyx_centerCoor)

        if self.isTrain:
            rd = random.random()
            if rd > 0.5:
                cut_list = [cutMin_list[0], cutMin_list[0]+self.crop_size_z, cutMin_list[1], cutMin_list[1]+self.crop_size_h, cutMin_list[2], cutMin_list[2]+self.crop_size_w]  ###  z,y,x
                start1, start2, start3 = self.random_crop_around_nodule(img_shape, cut_list, crop_size=self.crop_size, leftTop_ratio=0.3)
            elif rd > 0.1:
                start1, start2, start3 = self.random_crop_negative_nodule(img_shape, crop_size=self.crop_size)
            else:
                start1, start2, start3 = cutMin_list
        else:
            start1, start2, start3 = cutMin_list

        img_crop = img[start1:start1 + self.crop_size_z, start2:start2 + self.crop_size_h,
                   start3:start3 + self.crop_size_w]
        label_crop = label[start1:start1 + self.crop_size_z, start2:start2 + self.crop_size_h,
                     start3:start3 + self.crop_size_w]

        # print('before:', img_crop.shape, label_crop.shape)
        # compute the padding needed to reach crop_size
        if img_crop.shape != self.crop_size:
            pad_width = [(0, self.crop_size_z-img_crop.shape[0]), (0, self.crop_size_h-img_crop.shape[1]), (0, self.crop_size_w-img_crop.shape[2])]
            img_crop = np.pad(img_crop, pad_width, mode='constant', constant_values=0)
        if label_crop.shape != self.crop_size:
            pad_width = [(0, self.crop_size_z-label_crop.shape[0]), (0, self.crop_size_h-label_crop.shape[1]), (0, self.crop_size_w-label_crop.shape[2])]
            label_crop = np.pad(label_crop, pad_width, mode='constant', constant_values=0)

        # print('after:', img_crop.shape, label_crop.shape)
        img_crop = np.expand_dims(img_crop, 0)  # (1, 16, 96, 96)
        img_crop = torch.from_numpy(img_crop).float()

        label_crop = torch.from_numpy(label_crop).long()  # (16, 96, 96); the label does not need an extra channel dimension
        return img_crop, label_crop

    def __len__(self):
        return len(self.dataFile_path_bboxes)

    def load_img(self, path_to_img):
        if path_to_img.startswith('LKDS'):
            img = np.load(path_to_img)
        else:
            img, _ = nrrd.read(path_to_img)
        img = img.transpose((0, 2, 1))      # matches the xyz coordinate swap in getBboxes
        return img/255.0, img.shape


    def load_mask(self, path_to_mask):
        mask, _ = nrrd.read(path_to_mask)
        mask[mask>1] = 1
        mask = mask.transpose((0, 2, 1))    # matches the xyz coordinate swap in getBboxes
        return mask

    def getBboxes(self, bboxFile_path):
        bboxes_array = np.load(bboxFile_path, allow_pickle=True)
        bboxes_list = bboxes_array.tolist()
        one_path_bbox_list = []
        for zyx in bboxes_list:
            xyz = [zyx[0], zyx[2], zyx[1]]  # swap the last two coordinates to match the (0, 2, 1) image transpose
            one_path_bbox_list.append([bboxFile_path, xyz])

        return one_path_bbox_list

    def getCenterScope0(self, img_shape, zyx_centerCoor):  # earlier version, kept for reference
        cut_list = []  # values used for cropping
        for i in range(len(img_shape)):  # 0, 1, 2   →  z,y,x
            if i == 0:  # z
                a = zyx_centerCoor[-i - 1] - self.crop_size_z/2  # z
                b = zyx_centerCoor[-i - 1] + self.crop_size_z/2  # y,z
            else:  # y, x
                a = zyx_centerCoor[-i - 1] - self.crop_size_w/2
                b = zyx_centerCoor[-i - 1] + self.crop_size_w/2

            # beyond the image boundary, case 1
            if a < 0:
                a = self.crop_size_z
                b = self.crop_size_w
            # beyond the boundary, case 2
            elif b > img_shape[i]:
                if i == 0:
                    a = img_shape[i] - self.crop_size_z
                    b = img_shape[i]
                else:
                    a = img_shape[i] - self.crop_size_w
                    b = img_shape[i]
            else:
                pass

            cut_list.append(int(a))
            cut_list.append(int(b))

        return cut_list

    def getCenterScope(self, img_shape, zyx_centerCoor):
        img_z, img_y, img_x = img_shape
        zc, yc, xc = zyx_centerCoor

        zmin = max(0, zc - self.crop_size_z // 3)
        ymin = max(0, yc - self.crop_size_h // 2)
        xmin = max(0, xc - self.crop_size_w // 2)

        cutMin_list = [int(zmin), int(ymin), int(xmin)]

        return cutMin_list

    def random_crop_around_nodule(self, img_shape, cut_list, crop_size=(16, 96, 96), leftTop_ratio=0.3):
        """
        :param img:
        :param label:
        :param center:
        :param radius:
        :param cut_list:
        :param crop_size:
        :param leftTop_ratio: 越大,阴性样本越多(需要考虑crop_size)
        :return:
        """
        img_z, img_y, img_x = img_shape
        crop_z, crop_y, crop_x = crop_size
        z_min, z_max, y_min, y_max, x_min, x_max = cut_list
        # print('z_min, z_max, y_min, y_max, x_min, x_max:', z_min, z_max, y_min, y_max, x_min, x_max)

        z_min = max(0, int(z_min-crop_z*leftTop_ratio))
        z_max = min(img_z, int(z_min + crop_z*leftTop_ratio))
        y_min = max(0, int(y_min-crop_y*leftTop_ratio))
        y_max = min(img_y, int(y_min+crop_y*leftTop_ratio))
        x_min = max(0, int(x_min-crop_x*leftTop_ratio))
        x_max = min(img_x, int(x_min+crop_x*leftTop_ratio))

        z_start = random.randint(z_min, z_max)
        y_start = random.randint(y_min, y_max)
        x_start = random.randint(x_min, x_max)

        return z_start, y_start, x_start

    def random_crop_negative_nodule(self, img_shape, crop_size=(16, 96, 96), boundary_ratio=0.5):
        img_z, img_y, img_x = img_shape
        crop_z, crop_y, crop_x = crop_size

        z_min = 0                # crop_z * boundary_ratio
        z_max = img_z - crop_z   # img_z - crop_z * boundary_ratio
        y_min = 0                # crop_y * boundary_ratio
        y_max = img_y - crop_y   # img_y - crop_y * boundary_ratio
        x_min = 0                # crop_x * boundary_ratio
        x_max = img_x - crop_x   # img_x - crop_x * boundary_ratio

        z_start = random.randint(z_min, z_max)
        y_start = random.randint(y_min, y_max)
        x_start = random.randint(x_min, x_max)

        return z_start, y_start, x_start

The above is the complete code of the rewritten data flow, without any extra data augmentation operations yet. During training, three types of diversity are introduced (their sampling mix is sketched after the list):

  1. When the mask contains a nodule target, randomly shift where the nodule lands inside the patch;
  2. Randomly crop anywhere in the whole image, mainly to generate negative samples;
  3. Crop directly with the nodule as the center point.
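In the __getitem__ above, these three branches fire with probability 0.5, 0.4, and 0.1 respectively (rd > 0.5, 0.1 < rd <= 0.5, rd <= 0.1). A quick simulation of the mix:

import random
from collections import Counter

mix = Counter()
for _ in range(100000):
    rd = random.random()
    if rd > 0.5:
        mix['around_nodule'] += 1  # type 1: random shift around the nodule
    elif rd > 0.1:
        mix['negative'] += 1       # type 2: random whole-image crop
    else:
        mix['centered'] += 1       # type 3: nodule-centered crop
print(mix)  # roughly 50% / 40% / 10%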

The point of this is to account for where the nodule sits inside the patch, which may affect the final prediction. At the inference stage we do not know where a nodule is in the image; we can only traverse all the patches, stitch the predicted results into a complete mask, and then extract the locations of all nodules from that mask.

This requires that wherever a nodule appears in the image, it can be found, with as few false positives as possible.
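To make the traversal concrete, here is a minimal sliding-window sketch, assuming a model whose output for a (1, 1, z, y, x) input is a single-channel probability volume of the same spatial shape (for a two-channel softmax output you would take an argmax instead):

import numpy as np
import torch

def sliding_window_predict(model, img, crop_size=(16, 96, 96)):
    cz, cy, cx = crop_size
    Z, Y, X = img.shape
    full_mask = np.zeros(img.shape, dtype=np.uint8)
    for z in range(0, Z, cz):                    # non-overlapping windows for simplicity
        for y in range(0, Y, cy):
            for x in range(0, X, cx):
                patch = img[z:z+cz, y:y+cy, x:x+cx]
                pz, py, px = patch.shape
                pad = [(0, cz-pz), (0, cy-py), (0, cx-px)]  # pad border patches to full size
                patch = np.pad(patch, pad, mode='constant')
                inp = torch.from_numpy(patch).float()[None, None]  # (1, 1, cz, cy, cx)
                with torch.no_grad():
                    pred = (model(inp).squeeze() > 0.5).numpy().astype(np.uint8)
                full_mask[z:z+cz, y:y+cy, x:x+cx] = pred[:pz, :py, :px]
    return full_mask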

This is something I rarely see covered in papers; I do not know whether papers only care about the metrics and forget about false positives as a by-product. Another common approach is to cut the patches out in advance and read the patch arrays directly during training. That is not great either: it is not diverse enough, and it is quite cumbersome.

Two functions in this section deserve extra discussion: getCenterScope and random_crop_around_nodule. Why does getCenterScope divide the z offset by 3 instead of 2? This came out of checking the crops many times. On a two-dimensional plane with a known center point, the minimum corner (the top-left) should simply be the center coordinates minus half of the width and height. But when the z axis also subtracts one half, all the cropped nodules end up sitting very low in the patch. So here one third is subtracted instead, to shift the crop up a little along the z axis. I still have not figured out the reason for this; if you know, please leave some advice in the comments section.
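For reference, the difference between the two offsets is only a few voxels. With crop_size_z = 16 and a nodule at zc = 50 (made-up numbers):

crop_size_z = 16
zc = 50
zmin_half = zc - crop_size_z // 2   # 42: the nodule lands on slice 8 of 16, the exact middle
zmin_third = zc - crop_size_z // 3  # 45: the nodule lands on slice 5 of 16, nearer the top
print(zmin_half, zmin_third)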

random_crop_around_nodule constrains the minimum and maximum coordinates that the top-left corner of the crop can take, and picks a random start within that interval, which makes the nodule crops more diverse.

All I want is for the nodule to appear somewhere in every crop, which only requires the crop's top-left corner to fall within a certain range. The leftTop_ratio parameter controls how far that corner is allowed to drift from the nominal top-left point.

This value needs to be tuned yourself according to the patch size, and it is important to check the results multiple times.

3. Verify data flow

Writing the Dataset class is not the end of the job, because at this point you do not know whether the data flow actually meets your requirements. So it is well worth simulating the training process and inspecting each patch in advance.

That is what this chapter is for: let's plot out the images and masks and see whether anything is wrong. The viewing code is fairly simple, and you can copy it into your own projects later.

def getContours(output):
    img_seged = output.numpy().astype(np.uint8)
    img_seged = img_seged * 255

    # ---- Predict bounding box results with txt ----
    kernel = np.ones((5, 5), np.uint8)
    img_seged = cv2.dilate(img_seged, kernel=kernel)
    _, img_seged_p = cv2.threshold(img_seged, 127, 255, cv2.THRESH_BINARY)
    try:
        # OpenCV 3.x: findContours returns (image, contours, hierarchy)
        _, contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    except ValueError:
        # OpenCV 4.x: findContours returns (contours, hierarchy)
        contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    return contours

if __name__=='__main__':
    data_dir = r"./valid"

    dataset_valid = myDataset_v3(data_dir,  crop_size=(48, 96, 96), isTrain=False)  # build the dataset
    valid_loader = torch.utils.data.DataLoader(dataset_valid,  # build the dataloader
                                               batch_size=1, shuffle=False,
                                               num_workers=0)  # e.g. 16; use 0 if Windows warns that the page file is too small
    print("valid_dataloader_ok")
    print(len(valid_loader))
    for batch_index, (data, target) in tqdm(enumerate(valid_loader)):
        name = dataset_valid.dataFile_path_bboxes[batch_index]
        print('name:', name)

        print('image size ......')
        print(data.shape)  # torch.Size([batch, 1, 48, 96, 96])

        print('label size ......')
        print(target.shape)  # torch.Size([batch, 48, 96, 96])

        # display sample by sample within the batch
        for i in range(data.shape[0]):
            onePatch = data[i, 0, :, :]            # (48, 96, 96)
            onePatch_target = target[i, :, :, :]   # (48, 96, 96); use i, not 0, so batch_size > 1 also works
            print('one_patch:', onePatch.shape, np.max(onePatch.numpy()), np.min(onePatch.numpy()))

            row_num = 6
            column_num = 8
            fig, ax = plt.subplots(row_num, column_num, figsize=[14, 16])
            for m in range(row_num):
                for n in range(column_num):
                    one_pic = onePatch[m * column_num + n]  # step through all 48 slices across the 6x8 grid
                    img = one_pic.numpy()*255.0
                    # print('one_pic img:', one_pic.shape, np.max(one_pic.numpy()), np.min(one_pic.numpy()))

                    one_mask = onePatch_target[m * column_num + n]
                    contours = getContours(one_mask)
                    for contour in contours:
                        x, y, w, h = cv2.boundingRect(contour)
                        xmin, ymin, xmax, ymax = x, y, x + w, y + h
                        # print('contouts:', xmin, ymin, xmax, ymax)
                        cv2.drawContours(img, contour, -1, (0, 0, 255), 2)
                        # cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255),
                        #               thickness=1)

                    ax[m, n].imshow(img, cmap='gray')
                    ax[m, n].axis('off')


            # print('one_target:', onePatch.shape, np.max(onePatch.numpy()), np.min(onePatch.numpy()))
            fig, ax = plt.subplots(row_num, column_num, figsize=[14, 16])
            for m in range(row_num):
                for n in range(column_num):
                    one_pic = onePatch_target[m * column_num + n]
                    # print('one_pic mask:', one_pic.shape, np.max(one_pic.numpy()), np.min(one_pic.numpy()))

                    ax[m, n].imshow(one_pic, cmap='gray')
                    ax[m, n].axis('off')
            plt.show()

The displayed image looks like this:

(Figure: a 6×8 grid of patch slices with nodule contours drawn on the image view, and the corresponding mask slices shown in a second grid.)
Look at a good number of these figures; the more you look, the better you can verify whether there are problems with the nodule cropping. At the same time, you can count how the training data splits between positive patches containing nodules and all-black patches with no nodules (see the sketch below), which also gives a reference standard for modifying the code above.
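Counting positives works with the same loader: any patch whose mask has a nonzero voxel contains a nodule. A quick sketch:

n_pos, n_total = 0, 0
for _, (data, target) in enumerate(valid_loader):
    n_total += data.shape[0]
    n_pos += int((target.sum(dim=(1, 2, 3)) > 0).sum())  # patches with at least one mask voxel
print('positive patches: %d / %d' % (n_pos, n_total))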

4. Summary

This article is essentially a wrap-up of the data flow problem from the previous blog, together with its solution. It also demonstrates a process for verifying the data flow, which will be very useful when we move on to other tasks.

If you are a beginner, I believe you will get a lot out of it; if you came here for a project, you should have found a workable idea. Differences between datasets show up mainly in preprocessing, and for the training stage this article can help you get started quickly.

Finally, please leave a like and a favorite. If you have any questions, feel free to comment or message me. The training and validation code will be introduced later, and that part is also a key focus.


Origin: blog.csdn.net/wsLJQian/article/details/133963731