In the previous article, [3D Image Segmentation] VNet 3D Image Segmentation 2 based on Pytorch (Basic Data Flow), we ended by mentioning a problem encountered during the training phase.
In the segmentation training task on 3d data using the VNet model, the input size is 16*96*96, and the cropping is done inside the Dataset class, which cuts out the image and the mask. However, several problems were discovered during training:
- Loading the data took a long time: 30 minutes passed between starting the training script and the first printed batch iteration.
- With batch=64 and num_workers=8 in torch.utils.data.DataLoader, training stalled for a long time whenever the iteration count reached a multiple of 8.
- When training on 4 GPUs in parallel, GPU utilization stayed at 0 for long stretches, rose occasionally, and then dropped back to 0 almost immediately.
- Checking memory usage with free -m showed buff and cache growing steadily until memory was nearly full.
So what is the problem? The model trains normally and converges well, but it is far too slow. Analyzing the data reading code in myDataset, several places are likely to be time-consuming and memory-hungry:
- The getAnnotations function reads the file names and the corresponding nodule coordinates from the csv file and stores them in a dictionary, which stays in memory for the whole run;
- The getNpyFile_Path function is called for both dataFile_paths and labelFile_paths, so some work is repeated and the footprint of this part can double;
- The get_annos_label function has the same problem: some work is repeated and the footprint of this part can double.
All of these functions run in the class's __init__ stage. These repeated loops are probably the main reason so much time passes before the first batch loop starts. In addition, the duplicated memory usage degrades performance further and makes the rest of training slower.
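One way to check where the time goes is to time construction and per-item loading separately. The sketch below is only an illustration under stated assumptions: EagerDataset and LazyDataset are hypothetical toy classes, not the article's myDataset, and time.sleep stands in for the expensive csv parsing and file scanning.

```python
import time


class EagerDataset:
    """Hypothetical stand-in: does all expensive work up front in __init__."""

    def __init__(self, n_items, cost_per_item=0.001):
        # time.sleep simulates parsing annotations and scanning files per item
        self.data = []
        for i in range(n_items):
            time.sleep(cost_per_item)
            self.data.append(i)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class LazyDataset:
    """Hypothetical stand-in: keeps only a light index, pays the cost per item."""

    def __init__(self, n_items, cost_per_item=0.001):
        self.index = list(range(n_items))
        self.cost_per_item = cost_per_item

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        time.sleep(self.cost_per_item)
        return self.index[index]


def time_dataset(factory):
    """Return (seconds spent in __init__, seconds spent fetching the first item)."""
    t0 = time.perf_counter()
    ds = factory()
    t1 = time.perf_counter()
    ds[0]
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


eager_init, eager_first = time_dataset(lambda: EagerDataset(50))
lazy_init, lazy_first = time_dataset(lambda: LazyDataset(50))
print(f"eager: init {eager_init:.3f}s, first item {eager_first:.4f}s")
print(f"lazy:  init {lazy_init:.3f}s, first item {lazy_first:.4f}s")
```

If the real Dataset shows the eager pattern, with nearly all the time spent before the first item is returned, that points at __init__ rather than at the DataLoader workers.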
To solve these problems, version 2.0 of the data-loading Dataset was written for this article. The biggest change is that the nodule coordinates, which were originally read from the csv file, are now read from npy files. This way image, mask and Bbox are all single files with a one-to-one correspondence. The subsequent training runs confirmed this: the loading-time problem was solved and training became faster.
So as long as we simplify what is determined in the __init__ stage and reduce its memory usage, the problem should be solved. This article follows that principle: as much work as possible is moved into the data preprocessing stage, leaving only the simplest one-to-one structure, so that preprocessing runs up front and is never called while the Dataset is being constructed.
For LUNA16 data preprocessing, you can refer to the articles below, which generate the data used in this article:
- [3D Image Segmentation] VNet 3D Image Segmentation 6 based on Pytorch (data preprocessing)
- [3D Image Segmentation] VNet 3D Image Segmentation 7 based on Pytorch (data preprocessing)
- [3D Image Segmentation] VNet 3D Image Segmentation 8 based on Pytorch (CT lung parenchyma segmentation)
- [3D Image Segmentation] VNet 3D Image Segmentation 9 based on Pytorch (patch crop and merge operations)
1. Set up a data flow framework
In pytorch, the data flow for training follows the structure below. The main idea is as follows:
- __init__ runs during the class initialization phase. Here you need to determine a value from which everything needed for training can be obtained, while using as little memory and time as possible;
- __getitem__ takes the value determined in __init__, obtains one image and its label information, and performs operations such as reading and augmentation. It finally returns Tensor values;
- __len__ returns the number of samples in one epoch, i.e. the length of the value determined in __init__.
The following is a simple skeleton, kept here for reference; later data flows can be built by filling it in.
class myDataset_v3(Dataset):
    def __init__(self, data_dir, isTrain=True):
        self.data = []
        if isTrain:
            self.data ···
        else:
            self.data ···

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # ********** get file dir **********
        image, label = self.data[index]  # get whole data for one subject
        # ********** change data type from numpy to torch.Tensor **********
        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()
        return image, label
The parameters of this class are introduced in detail in another article. If you are interested, you can read it directly: [BraTS] Brain Tumor Segmentation Brain Tumor Split 3 (Building Data Flow)
2. Improve the content of the framework
I believe that after the previous four blogs (6, 7, 8 and 9), you have processed the original Luna16 dataset into a one-to-one correspondence. The data files required for training include:
- _bboxes.npy: records the center coordinates and radius of each nodule;
- _clean.nrrd: the original CT image array;
- _mask.nrrd: the label mask array, with the same shape as _clean.nrrd.
There are also some other .npy files, which record quantities from the transformation phase and are not used during training, so they are not covered here. The three files above are the ones that matter, and they correspond one to one by seriesUID.
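Since everything later depends on this one-to-one correspondence, it can be worth checking it before training. The following is a minimal sketch, assuming the three suffixes named above and a flat directory; find_complete_series is a hypothetical helper, and the temporary directory with empty files exists only to make the example runnable.

```python
import os
import tempfile

SUFFIXES = ("_bboxes.npy", "_clean.nrrd", "_mask.nrrd")


def find_complete_series(data_dir):
    """Return (complete, incomplete).

    complete:   sorted seriesUIDs that have all three files
    incomplete: dict mapping seriesUID -> list of missing suffixes
    """
    names = set(os.listdir(data_dir))
    uids = set()
    for name in names:
        for suffix in SUFFIXES:
            if name.endswith(suffix):
                uids.add(name[: -len(suffix)])
    complete, incomplete = [], {}
    for uid in sorted(uids):
        missing = [s for s in SUFFIXES if uid + s not in names]
        if missing:
            incomplete[uid] = missing
        else:
            complete.append(uid)
    return complete, incomplete


# Demo on a throwaway directory with placeholder (empty) files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("uidA_bboxes.npy", "uidA_clean.nrrd", "uidA_mask.nrrd", "uidB_bboxes.npy"):
        open(os.path.join(tmp, name), "w").close()
    complete, incomplete = find_complete_series(tmp)

print(complete)    # ['uidA']
print(incomplete)  # {'uidB': ['_clean.nrrd', '_mask.nrrd']}
```

Running such a check once after preprocessing avoids a failed file read partway through an epoch.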
Given this, when we build the myDataset_v3(Dataset) data class, the question is: in the __init__ stage, what can serve as the anchor so that, with as little memory as possible, the required images and annotations can be fetched in the __getitem__ stage?
The answer is the seriesUID file name. One name leads to all three files, so a single list is enough, and this is the most memory-saving approach. Our definition of the __init__ stage is therefore as follows:
class myDataset_v3(Dataset):
    def __init__(self, data_dir, crop_size=(16, 96, 96), isTrain=False):
        self.bboxesFile_path = []
        for file in os.listdir(data_dir):
            if '_bboxes.npy' in file:
                self.bboxesFile_path.append(os.path.join(data_dir, file))
        self.crop_size = crop_size
        self.crop_size_z, self.crop_size_h, self.crop_size_w = crop_size
        self.isTrain = isTrain
The definition of __len__ then follows naturally:
    def __len__(self):
        return len(self.bboxesFile_path)
The most important and most difficult part is the definition of __getitem__. It has to do several things:
- Get the path of each file;
- Get the data stored in each file;
- Crop out the target patch;
- Assemble the results into Tensor.
While defining __getitem__, however, a problem came up, as follows:
    def __getitem__(self, index):
        bbox_path = self.bboxesFile_path[index]
        img_path = bbox_path.replace('_bboxes.npy', '_clean.nrrd')
        label_path = bbox_path.replace('_bboxes.npy', '_mask.nrrd')
        img, img_shape = self.load_img(img_path)
        label = self.load_mask(label_path)
        zyx_centerCoor = self.getBboxes(bbox_path)

    def getBboxes(self, bboxFile_path):
        bboxes_array = np.load(bboxFile_path, allow_pickle=True)
        bboxes_list = bboxes_array.tolist()
        xyz_list = [[zyx[0], zyx[2], zyx[1]] for zyx in bboxes_list]
        return random.choice(xyz_list)
The cause is that one _bboxes.npy file may record more than one nodule. If the bbox is fetched inside __getitem__, only one patch can be cut out per call, so a scan with several nodules cannot have all of them handled. That is why I used random.choice here to pick one nodule at random.
However, this approach is not good, because it reduces how often some nodules appear during learning. Although the choice is random, it effectively shrinks the amount of data for scans with several nodules: over the same number of epochs, a nodule from a scan with only one nodule is learned relatively more often.
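To make the imbalance concrete, here is a small calculation under assumed nodule counts. The scan names and counts are hypothetical; the real counts come from each _bboxes.npy. It compares the expected visits per nodule per epoch when one nodule is drawn per scan with an index that holds one entry per nodule.

```python
from fractions import Fraction

# Hypothetical nodule counts per scan.
nodules_per_scan = {"scan_a": 1, "scan_b": 3, "scan_c": 2}


def per_epoch_visits_random_choice(counts):
    """Expected visits per nodule per epoch when one nodule is drawn per scan."""
    return {scan: Fraction(1, k) for scan, k in counts.items()}


def per_epoch_visits_flat_index(counts):
    """Visits per nodule per epoch when the index holds one entry per nodule."""
    return {scan: Fraction(1, 1) for scan in counts}


random_choice = per_epoch_visits_random_choice(nodules_per_scan)
flat = per_epoch_visits_flat_index(nodules_per_scan)
samples_random_choice = len(nodules_per_scan)  # one sample per scan
samples_flat = sum(nodules_per_scan.values())  # one sample per nodule

for scan in nodules_per_scan:
    print(scan, "random.choice:", random_choice[scan], "flat index:", flat[scan])
print("epoch length:", samples_random_choice, "vs", samples_flat)
```

With these assumed counts, a nodule in scan_b is expected to be seen once every three epochs under random.choice, while a per-nodule index visits it every epoch and makes the epoch twice as long.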
To solve this problem, each nodule is paired with its file name when the index is built, so that every nodule gets an equal chance of being sampled. The code looks like this:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
import nrrd
import cv2
class myDataset_v3(Dataset):
    def __init__(self, data_dir, crop_size=(16, 96, 96), isTrain=False):
        # one entry per nodule: [path of the _bboxes.npy file, nodule center coordinate]
        self.dataFile_path_bboxes = []
        for file in os.listdir(data_dir):
            if '_bboxes.npy' in file:
                one_path_bbox_list = self.getBboxes(os.path.join(data_dir, file))
                self.dataFile_path_bboxes.extend(one_path_bbox_list)
        self.crop_size = crop_size
        self.crop_size_z, self.crop_size_h, self.crop_size_w = crop_size
        self.isTrain = isTrain

    def __getitem__(self, index):
        bbox_path, zyx_centerCoor = self.dataFile_path_bboxes[index]
        img_path = bbox_path.replace('_bboxes.npy', '_clean.nrrd')
        label_path = bbox_path.replace('_bboxes.npy', '_mask.nrrd')
        img, img_shape = self.load_img(img_path)
        # print('img_shape:', img_shape)
        label = self.load_mask(label_path)
        # print('zyx_centerCoor:', zyx_centerCoor)
        cutMin_list = self.getCenterScope(img_shape, zyx_centerCoor)
        if self.isTrain:
            rd = random.random()
            if rd > 0.5:
                # nodule stays inside the patch, but at a random position
                cut_list = [cutMin_list[0], cutMin_list[0]+self.crop_size_z,
                            cutMin_list[1], cutMin_list[1]+self.crop_size_h,
                            cutMin_list[2], cutMin_list[2]+self.crop_size_w]  ### z,y,x
                start1, start2, start3 = self.random_crop_around_nodule(img_shape, cut_list, crop_size=self.crop_size, leftTop_ratio=0.3)
            elif rd > 0.1:
                # random crop anywhere in the volume, mainly negative samples
                start1, start2, start3 = self.random_crop_negative_nodule(img_shape, crop_size=self.crop_size)
            else:
                # crop placed directly from the nodule center
                start1, start2, start3 = cutMin_list
        else:
            start1, start2, start3 = cutMin_list
        img_crop = img[start1:start1 + self.crop_size_z, start2:start2 + self.crop_size_h,
                       start3:start3 + self.crop_size_w]
        label_crop = label[start1:start1 + self.crop_size_z, start2:start2 + self.crop_size_h,
                           start3:start3 + self.crop_size_w]
        # print('before:', img_crop.shape, label_crop.shape)
        # compute the padding needed when the crop runs past the image border
        if img_crop.shape != tuple(self.crop_size):
            pad_width = [(0, self.crop_size_z-img_crop.shape[0]), (0, self.crop_size_h-img_crop.shape[1]), (0, self.crop_size_w-img_crop.shape[2])]
            img_crop = np.pad(img_crop, pad_width, mode='constant', constant_values=0)
        if label_crop.shape != tuple(self.crop_size):
            pad_width = [(0, self.crop_size_z-label_crop.shape[0]), (0, self.crop_size_h-label_crop.shape[1]), (0, self.crop_size_w-label_crop.shape[2])]
            label_crop = np.pad(label_crop, pad_width, mode='constant', constant_values=0)
        # print('after:', img_crop.shape, label_crop.shape)
        img_crop = np.expand_dims(img_crop, 0)  # (1, 16, 96, 96)
        img_crop = torch.from_numpy(img_crop).float()
        label_crop = torch.from_numpy(label_crop).long()  # (16, 96, 96), the label needs no channel dimension
        return img_crop, label_crop

    def __len__(self):
        return len(self.dataFile_path_bboxes)

    def load_img(self, path_to_img):
        if path_to_img.startswith('LKDS'):
            img = np.load(path_to_img)
        else:
            img, _ = nrrd.read(path_to_img)
        img = img.transpose((0, 2, 1))  # matches the coordinate swap in getBboxes
        return img/255.0, img.shape

    def load_mask(self, path_to_mask):
        mask, _ = nrrd.read(path_to_mask)
        mask[mask>1] = 1
        mask = mask.transpose((0, 2, 1))  # matches the coordinate swap in getBboxes
        return mask

    def getBboxes(self, bboxFile_path):
        bboxes_array = np.load(bboxFile_path, allow_pickle=True)
        bboxes_list = bboxes_array.tolist()
        one_path_bbox_list = []
        for zyx in bboxes_list:
            xyz = [zyx[0], zyx[2], zyx[1]]  # swap the last two coordinates
            one_path_bbox_list.append([bboxFile_path, xyz])
        return one_path_bbox_list

    def getCenterScope0(self, img_shape, zyx_centerCoor):
        # earlier version, kept for reference; __getitem__ calls getCenterScope instead
        cut_list = []  # values used for cropping
        for i in range(len(img_shape)):  # 0, 1, 2 → z,y,x
            if i == 0:  # z
                a = zyx_centerCoor[-i - 1] - self.crop_size_z/2  # z
                b = zyx_centerCoor[-i - 1] + self.crop_size_z/2  # y,z
            else:  # y, x
                a = zyx_centerCoor[-i - 1] - self.crop_size_w/2
                b = zyx_centerCoor[-i - 1] + self.crop_size_w/2
            # out of image bounds, case 1
            if a < 0:
                a = self.crop_size_z
                b = self.crop_size_w
            # out of image bounds, case 2
            elif b > img_shape[i]:
                if i == 0:
                    a = img_shape[i] - self.crop_size_z
                    b = img_shape[i]
                else:
                    a = img_shape[i] - self.crop_size_w
                    b = img_shape[i]
            else:
                pass
            cut_list.append(int(a))
            cut_list.append(int(b))
        return cut_list

    def getCenterScope(self, img_shape, zyx_centerCoor):
        img_z, img_y, img_x = img_shape
        zc, yc, xc = zyx_centerCoor
        zmin = max(0, zc - self.crop_size_z // 3)
        ymin = max(0, yc - self.crop_size_h // 2)
        xmin = max(0, xc - self.crop_size_w // 2)
        cutMin_list = [int(zmin), int(ymin), int(xmin)]
        return cutMin_list

    def random_crop_around_nodule(self, img_shape, cut_list, crop_size=(16, 96, 96), leftTop_ratio=0.3):
        """
        :param img_shape: shape of the whole image, (z, y, x)
        :param cut_list: [z_min, z_max, y_min, y_max, x_min, x_max] of the crop from getCenterScope
        :param crop_size: size of the patch, (z, y, x)
        :param leftTop_ratio: the larger it is, the more negative samples (crop_size must be considered)
        :return: start coordinates (z, y, x) of the patch
        """
        img_z, img_y, img_x = img_shape
        crop_z, crop_y, crop_x = crop_size
        z_min, z_max, y_min, y_max, x_min, x_max = cut_list
        # print('z_min, z_max, y_min, y_max, x_min, x_max:', z_min, z_max, y_min, y_max, x_min, x_max)
        z_min = max(0, int(z_min-crop_z*leftTop_ratio))
        z_max = min(img_z, int(z_min + crop_z*leftTop_ratio))
        y_min = max(0, int(y_min-crop_y*leftTop_ratio))
        y_max = min(img_y, int(y_min+crop_y*leftTop_ratio))
        x_min = max(0, int(x_min-crop_x*leftTop_ratio))
        x_max = min(img_x, int(x_min+crop_x*leftTop_ratio))
        z_start = random.randint(z_min, z_max)
        y_start = random.randint(y_min, y_max)
        x_start = random.randint(x_min, x_max)
        return z_start, y_start, x_start

    def random_crop_negative_nodule(self, img_shape, crop_size=(16, 96, 96), boundary_ratio=0.5):
        img_z, img_y, img_x = img_shape
        crop_z, crop_y, crop_x = crop_size
        # max(0, ...) keeps randint from failing when the image is smaller than the crop
        z_min = 0  # crop_z*boundary_ratio
        z_max = max(0, img_z-crop_z)  # img_z - crop_z*boundary_ratio
        y_min = 0  # crop_y*boundary_ratio
        y_max = max(0, img_y-crop_y)  # img_y - crop_y*boundary_ratio
        x_min = 0  # crop_x*boundary_ratio
        x_max = max(0, img_x-crop_x)  # img_x - crop_x*boundary_ratio
        z_start = random.randint(z_min, z_max)
        y_start = random.randint(y_min, y_max)
        x_start = random.randint(x_min, x_max)
        return z_start, y_start, x_start
The above is the complete code of the rewritten data flow, without any data augmentation operations. During training, three kinds of diversity are introduced:
- Keep the nodule target inside the mask, but randomly change the nodule's position within the patch;
- Randomly crop anywhere in the whole image, mainly to generate negative samples;
- Crop directly with the nodule as the center point.
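The thresholds on rd in __getitem__ fix how often each of these three modes is used. The sketch below mirrors only that branching; choose_crop_mode and the mode names are hypothetical labels added for illustration, and the fixed seed is there to make the run repeatable.

```python
import random
from collections import Counter


def choose_crop_mode(rd):
    """Mirror the thresholds used in __getitem__ when isTrain is True."""
    if rd > 0.5:
        return "around_nodule"    # random offset, nodule kept inside the patch
    elif rd > 0.1:
        return "random_negative"  # random crop anywhere, mostly negative samples
    return "centered"             # crop placed directly from the nodule center


rng = random.Random(0)  # fixed seed so the run is repeatable
n_draws = 100_000
counts = Counter(choose_crop_mode(rng.random()) for _ in range(n_draws))
fractions = {mode: n / n_draws for mode, n in counts.items()}
print(fractions)
```

So roughly 50% of training samples are shifted nodule crops, 40% are random crops and 10% are centered crops. Changing the two thresholds is the simplest way to rebalance positive and negative patches.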
The purpose is to account for the position of the nodule within the patch, which may affect the final prediction. In the final inference stage we do not know where the nodules are in the image. We can only traverse all the patches, stitch the predicted results into a complete mask, and then process that mask to obtain the locations of all nodules.
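As a rough illustration of what traversing all the patches means, the sketch below enumerates the start corners of a sliding window over a volume. It is an assumption-laden sketch, not the implementation from the patch crop and merge article: window_starts and patch_starts are hypothetical helpers, and the volume shape and the non-overlapping stride are made up for the example.

```python
def window_starts(length, size, stride):
    """Start indices along one axis so that windows of `size` cover [0, length)."""
    if length <= size:
        return [0]  # a single window; the caller pads the shortfall
    starts = list(range(0, length - size + 1, stride))
    if starts[-1] != length - size:
        starts.append(length - size)  # shift the last window back to end at the border
    return starts


def patch_starts(volume_shape, crop_size, stride):
    """All (z, y, x) start corners for traversing a volume patch by patch."""
    z_starts, y_starts, x_starts = (
        window_starts(n, s, st) for n, s, st in zip(volume_shape, crop_size, stride)
    )
    return [(z, y, x) for z in z_starts for y in y_starts for x in x_starts]


starts = patch_starts((40, 200, 200), (16, 96, 96), (16, 96, 96))
print(len(starts))  # 27
print(starts[:3])   # [(0, 0, 0), (0, 0, 96), (0, 0, 104)]
```

Because the window grid is fixed and ignores where the nodules are, a nodule can land anywhere inside a patch at inference time, which is exactly why the training crops vary its position.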
This requires that a nodule be found wherever it appears in the image, with as few false positives as possible.
This is something I rarely see covered in papers. I don't know whether papers focus only on the metrics and overlook the false positives that come with them. Another common approach is to cut the patches out in advance and read the patch arrays directly for training. I don't think that is good either: it is not diverse enough and it is quite troublesome.
Two functions in this section deserve more explanation: getCenterScope and random_crop_around_nodule. Why does getCenterScope divide by 3 on the z axis rather than by 2? This comes from checking the crops many times: when dividing by 2, all the nodules end up low in the patch. I have not understood the reason yet. If you know, please leave a message.
On a two-dimensional plane with a known center point, the minimum (upper-left) corner should be the center coordinates minus half the width and half the height. However, when half the crop depth is also subtracted on the z axis, all the cropped nodules sit very low.
So here we subtract one third instead, which moves the nodule up a little on the z axis. I still haven't figured this out, so if you know, please give me some advice in the comments section.
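A worked example may help when checking these offsets by hand. The function below is a standalone copy of the arithmetic in getCenterScope; crop_start is a hypothetical name, and the nodule centers are made-up values.

```python
def crop_start(center_zyx, crop_size=(16, 96, 96)):
    """Standalone copy of the arithmetic in getCenterScope."""
    zc, yc, xc = center_zyx
    crop_z, crop_h, crop_w = crop_size
    zmin = max(0, zc - crop_z // 3)  # one third on z
    ymin = max(0, yc - crop_h // 2)  # one half on y
    xmin = max(0, xc - crop_w // 2)  # one half on x
    return [int(zmin), int(ymin), int(xmin)]


center = (20, 100, 150)  # hypothetical nodule center (z, y, x)
start = crop_start(center)
offset_in_patch = [c - s for c, s in zip(center, start)]
print(start)                     # [15, 52, 102]
print(offset_in_patch)           # [5, 48, 48]
print(crop_start((2, 10, 150)))  # [0, 0, 102]
```

With a 16*96*96 crop the nodule center sits at slice 5 of 16 on z and at 48 of 96 on y and x. Near the image border the start is clamped to 0, so the nodule is no longer at that fixed offset.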
random_crop_around_nodule controls the minimum and maximum coordinates of the crop's upper-left corner and picks a random start within that interval, which makes the nodule crops more diverse. As shown below:
I just want the nodule to appear in every crop, so the upper-left corner of the crop only needs to fall within a certain range. The leftTop_ratio parameter controls how far the crop's upper-left corner may move away from the upper-left corner computed by getCenterScope.
This value has to be chosen yourself based on the size of the patch. It is important to check the result several times.
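One way to do that check without plotting is to enumerate every start the function can return on one axis and test whether the nodule center is still inside the patch. This is a sketch under assumptions: start_range copies the per-axis arithmetic from random_crop_around_nodule, center_always_inside is a hypothetical helper, and the image size and center are made-up values.

```python
def start_range(cut_min, img_len, crop_len, leftTop_ratio=0.3):
    """Standalone copy of the per-axis range computed in random_crop_around_nodule."""
    lo = max(0, int(cut_min - crop_len * leftTop_ratio))
    hi = min(img_len, int(lo + crop_len * leftTop_ratio))
    return lo, hi


def center_always_inside(center, cut_min, img_len, crop_len, leftTop_ratio=0.3):
    """True if every possible start keeps `center` inside [start, start + crop_len)."""
    lo, hi = start_range(cut_min, img_len, crop_len, leftTop_ratio)
    # random.randint(lo, hi) is inclusive on both ends, so hi is checked as well
    return all(s <= center < s + crop_len for s in range(lo, hi + 1))


# Hypothetical y axis: nodule center at 100 in a 512-wide image, crop of 96.
y_center = 100
y_cut_min = y_center - 96 // 2  # what getCenterScope would return for this axis
print(start_range(y_cut_min, 512, 96))                                        # (23, 51)
print(center_always_inside(y_center, y_cut_min, 512, 96))                     # True
print(center_always_inside(y_center, y_cut_min, 512, 96, leftTop_ratio=0.6))  # False
```

In this example a ratio of 0.3 keeps the center inside every possible patch, while 0.6 lets some starts push the center out, which matches the note in the docstring that a larger leftTop_ratio produces more negative samples. The check covers only the center, not the full extent of the nodule.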
3. Verify data flow
Constructing a class function with a large amount of data is not finished yet. Because you don't know whether the data flow at this time meets your requirements. So it would be great if we could simulate the training process and see the results of eachpatch
in advance.
This chapter is for this purpose. Let’s type out the images and masks to see if there are any problems. The viewing method is also relatively simple, and you can copy it and use it in your own projects later.
def getContours(output):
    img_seged = output.numpy().astype(np.uint8)
    img_seged = img_seged * 255
    # ---- Predict bounding box results with txt ----
    kernel = np.ones((5, 5), np.uint8)
    img_seged = cv2.dilate(img_seged, kernel=kernel)
    _, img_seged_p = cv2.threshold(img_seged, 127, 255, cv2.THRESH_BINARY)
    try:
        # OpenCV 3.x: findContours returns (image, contours, hierarchy)
        _, contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    except ValueError:
        # OpenCV 4.x: findContours returns (contours, hierarchy)
        contours, _ = cv2.findContours(np.uint8(img_seged_p), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return contours

if __name__=='__main__':
    data_dir = r"./valid"
    dataset_valid = myDataset_v3(data_dir, crop_size=(48, 96, 96), isTrain=False)  # build the dataset
    valid_loader = torch.utils.data.DataLoader(dataset_valid,  # build the dataloader
                                               batch_size=1, shuffle=False,
                                               num_workers=0)  # 16)  # set to 0 if you get a "page file too small" warning
    print("valid_dataloader_ok")
    print(len(valid_loader))
    for batch_index, (data, target) in tqdm(enumerate(valid_loader)):
        # valid only because batch_size=1 and shuffle=False
        name = dataset_valid.dataFile_path_bboxes[batch_index]
        print('name:', name)
        print('image size ......')
        print(data.shape)  # torch.Size([batch, 1, 48, 96, 96])
        print('label size ......')
        print(target.shape)  # torch.Size([batch, 48, 96, 96])
        # display batch by batch
        for i in range(data.shape[0]):
            onePatch = data[i, 0, :, :]
            onePatch_target = target[i, :, :, :]
            print('one_patch:', onePatch.shape, np.max(onePatch.numpy()), np.min(onePatch.numpy()))
            # 6 x 8 = 48 subplots, one per z slice of the (48, 96, 96) patch
            row_num = 6
            column_num = 8
            fig, ax = plt.subplots(row_num, column_num, figsize=[14, 16])
            for m in range(row_num):
                for n in range(column_num):
                    one_pic = onePatch[m * column_num + n]
                    img = one_pic.numpy()*255.0
                    # print('one_pic img:', one_pic.shape, np.max(one_pic.numpy()), np.min(one_pic.numpy()))
                    one_mask = onePatch_target[m * column_num + n]
                    contours = getContours(one_mask)
                    for contour in contours:
                        x, y, w, h = cv2.boundingRect(contour)
                        xmin, ymin, xmax, ymax = x, y, x + w, y + h
                        # print('contouts:', xmin, ymin, xmax, ymax)
                        cv2.drawContours(img, contour, -1, (0, 0, 255), 2)
                        # cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255),
                        #               thickness=1)
                    ax[m, n].imshow(img, cmap='gray')
                    ax[m, n].axis('off')
            # print('one_target:', onePatch.shape, np.max(onePatch.numpy()), np.min(onePatch.numpy()))
            fig, ax = plt.subplots(row_num, column_num, figsize=[14, 16])
            for m in range(row_num):
                for n in range(column_num):
                    one_pic = onePatch_target[m * column_num + n]
                    # print('one_pic mask:', one_pic.shape, np.max(one_pic.numpy()), np.min(one_pic.numpy()))
                    ax[m, n].imshow(one_pic, cmap='gray')
                    ax[m, n].axis('off')
            plt.show()
The displayed image looks like this:
You can look at a few more pictures. The more you look at, the better you can verify whether the nodule cropping has any problems. You can also iterate the dataset in training mode to see what share of the samples are positive patches containing nodules and what share are all-black patches without nodules. This gives a reference for adjusting the code above.
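Counting the share of positive patches can be done with a few lines. The sketch below is a minimal version under assumptions: positive_fraction is a hypothetical helper, and the synthetic numpy patches stand in for what the Dataset would return.

```python
import numpy as np


def positive_fraction(samples):
    """Fraction of (image, mask) pairs whose mask has at least one nonzero voxel."""
    total = 0
    positive = 0
    for _, mask in samples:
        total += 1
        if np.asarray(mask).any():
            positive += 1
    return positive / total if total else 0.0


# Synthetic stand-in for a dataset: 3 patches with a nodule, 1 all-background patch.
rng = np.random.default_rng(0)
samples = []
for has_nodule in (True, True, False, True):
    image = rng.random((16, 96, 96), dtype=np.float32)
    mask = np.zeros((16, 96, 96), dtype=np.int64)
    if has_nodule:
        mask[8, 40:50, 40:50] = 1
    samples.append((image, mask))

print(positive_fraction(samples))  # 0.75
```

With the real data, the same function could be fed the items of myDataset_v3(data_dir, isTrain=True) one by one. Because the crops are random, the result varies a little from run to run.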
4. Summary
This article summarizes the data flow problem from the previous blog and presents a solution to it. It also shows a way to verify the data flow, which will be useful when we move on to other tasks.
If you are a beginner, I believe you will find it rewarding. If you came here for a project, you have probably found an idea. Differences between datasets show up mainly in preprocessing; for the training stage, this article can help you get started quickly.
Finally, please leave your likes and favorites. If you have any questions, leave a comment or send a private message. The training and validation code will be introduced later, and that part is also a focus.