pytorch实战-Unet3d(LiTS)

本文采用Unet3d进行LiTS腹部CT肝脏肿瘤分割

数据集的train集合一共130个样例,都为nii格式,原始CT数据为volume-*.nii,分割的ground truth为segmentation-0.nii,其中0为背景,1为肝脏,2为肿瘤,但是并不是每个样例里边都含有肿瘤

本来是准备用https://blog.csdn.net/py184473894/article/details/88558886这里的keras实现的unet进行这个数据集的分割的,但是后来发现,不知道是我的代码问题,还是keras有bug,在计算肿瘤的dice的时候会计算出错,所以训练不出肿瘤的分割,我在keras的github上提了这个issue,但是还没人回复,如果有大佬可以解决的话,麻烦联系我,或者在issue下边回复

LiTS数据的预处理

https://github.com/assassint2017/MICCAI-LITS2017/blob/master/data_prepare/get_fix_data.py

在这里使用了这个源代码进行,找到包含肝脏或者肿瘤的slice,然后上下取n片,作为训练集合

    def fix_data(self):
        upper = 200
        lower = -200
        expand_slice = 20  # 轴向上向外扩张的slice数量
        size = 48  # 取样的slice数量
        stride = 3  # 取样的步长
        down_scale = 0.5
        slice_thickness = 2

        for ct_file in os.listdir(self.row_root_path + 'data/'):
            print(ct_file)
            # 将CT和金标准入读内存
            ct = sitk.ReadImage(os.path.join(self.row_root_path + 'data/', ct_file), sitk.sitkInt16)
            ct_array = sitk.GetArrayFromImage(ct)

            seg = sitk.ReadImage(os.path.join(self.row_root_path + 'label/', ct_file.replace('volume', 'segmentation')),
                                 sitk.sitkInt8)
            seg_array = sitk.GetArrayFromImage(seg)

            print(ct_array.shape, seg_array.shape)

            # 将金标准中肝脏和肝肿瘤的标签融合为一个
            seg_array[seg_array > 0] = 1

            # 将灰度值在阈值之外的截断掉
            ct_array[ct_array > upper] = upper
            ct_array[ct_array < lower] = lower

            # 找到肝脏区域开始和结束的slice,并各向外扩张
            z = np.any(seg_array, axis=(1, 2))
            start_slice, end_slice = np.where(z)[0][[0, -1]]

            # 两个方向上各扩张个slice
            if start_slice - expand_slice < 0:
                start_slice = 0
            else:
                start_slice -= expand_slice

            if end_slice + expand_slice >= seg_array.shape[0]:
                end_slice = seg_array.shape[0] - 1
            else:
                end_slice += expand_slice

            print(str(start_slice) + '--' + str(end_slice))
            # 如果这时候剩下的slice数量不足size,直接放弃,这样的数据很少
            if end_slice - start_slice + 1 < size:
                print('!!!!!!!!!!!!!!!!')
                print(ct_file, 'too little slice')
                print('!!!!!!!!!!!!!!!!')
                continue

            ct_array = ct_array[start_slice:end_slice + 1, :, :]
            seg_array = sitk.GetArrayFromImage(seg)
            seg_array = seg_array[start_slice:end_slice + 1, :, :]

            new_ct = sitk.GetImageFromArray(ct_array)
            new_seg = sitk.GetImageFromArray(seg_array)

            sitk.WriteImage(new_ct, os.path.join(self.data_root_path + 'data/', ct_file))
            sitk.WriteImage(new_seg,
                            os.path.join(self.data_root_path + 'label/', ct_file.replace('volume', 'segmentation')))

基本上和源代码没有更改 

 

LiTS数据的读取

首先是将130个数据随机分为训练集(0.8)和验证集(0.1)和测试集(0.1)

1、读取volume和segmentation

2、进行scale,将分辨率压缩

3、每个样例随机截取n个(depth,height,width)大小的3维块作为一个输入的batch

4、数据归一化到0-1

5、将读取函数包装为dataset、dataloader

使用的时候主要使用了以下函数

def next_train_batch_3d_sub_by_index(self, train_batch_size, crop_size, index,resize_scale=1):
        train_imgs = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
        train_labels = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
        img, label = self.get_np_data_3d(self.train_name_list[index],resize_scale=resize_scale)
        for i in range(train_batch_size):
            sub_img, sub_label = util.random_crop_3d(img, label, crop_size)

            sub_img = sub_img[:, :, :, np.newaxis]
            sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)

            train_imgs[i] = sub_img
            train_labels[i] = sub_label_onehot

        return train_imgs, train_labels

 val集合类似

U-Net3d搭建

这里其实没什么好讲的,主要使用几个模块,resblock,seblock,RecombinationBlock、denseBlock等,然后上采样方式可以选是线性插值或者是deconv

class UNet(nn.Module):
    def __init__(self, in_channels, filter_num_list, class_num, conv_block=RecombinationBlock, net_mode='2d'):
        super(UNet, self).__init__()

        if net_mode == '2d':
            conv = nn.Conv2d
        elif net_mode == '3d':
            conv = nn.Conv3d
        else:
            conv = None

        self.inc = conv(in_channels, 16, 1)

        # down
        self.down1 = Down(16, filter_num_list[0], conv_block=conv_block, net_mode=net_mode)
        self.down2 = Down(filter_num_list[0], filter_num_list[1], conv_block=conv_block, net_mode=net_mode)
        self.down3 = Down(filter_num_list[1], filter_num_list[2], conv_block=conv_block, net_mode=net_mode)
        self.down4 = Down(filter_num_list[2], filter_num_list[3], conv_block=conv_block, net_mode=net_mode)

        self.bridge = conv_block(filter_num_list[3], filter_num_list[4], net_mode=net_mode)

        # up
        self.up1 = Up(filter_num_list[4], filter_num_list[3], filter_num_list[3], conv_block=conv_block,
                      net_mode=net_mode)
        self.up2 = Up(filter_num_list[3], filter_num_list[2], filter_num_list[2], conv_block=conv_block,
                      net_mode=net_mode)
        self.up3 = Up(filter_num_list[2], filter_num_list[1], filter_num_list[1], conv_block=conv_block,
                      net_mode=net_mode)
        self.up4 = Up(filter_num_list[1], filter_num_list[0], filter_num_list[0], conv_block=conv_block,
                      net_mode=net_mode)

        self.class_conv = conv(filter_num_list[0], class_num, 1)

    def forward(self, input):

        x = input

        x = self.inc(x)

        conv1, x = self.down1(x)

        conv2, x = self.down2(x)

        conv3, x = self.down3(x)

        conv4, x = self.down4(x)

        x = self.bridge(x)

        x = self.up1(x, conv4)

        x = self.up2(x, conv3)

        x = self.up3(x, conv2)

        x = self.up4(x, conv1)

        x = self.class_conv(x)

        x = nn.Softmax(1)(x)

        return x

loss metrics记录

之前因为是使用keras,记录loss和metrics太方便了,但是现在用了torch都要自己来写,然后想要使用tensorboard来实现,所以找到了https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py

用这个代码就可以将自己想要的loss和metrics记录下来

logger.scalar_summary('val_loss', val_loss, epoch)
logger.scalar_summary('val_dice0', val_dice0, epoch)
logger.scalar_summary('val_dice1', val_dice1, epoch)
logger.scalar_summary('val_dice2', val_dice2, epoch)

完整代码可以在我的github上找到,代码还在完善中,因为肿瘤的数量较少,所以对肿瘤的分割效果不太好,还在改进中

https://github.com/panxiaobai/lits_pytorch

发布了164 篇原创文章 · 获赞 36 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/py184473894/article/details/96186514
今日推荐