nnUnet: 使用自己(自定义)网络训练

一、简介

nnUnet是有监督的医学图像分割绕不开的话题(虽然看到有些文章对比实验没加nnUnet?可能比不过? ),其卓越的性能和简易的方法,为相关研究者提供了一项强有力的工具。然而,由于高度封装性,在原先代码中嵌入自定义网络进行训练,并不是十分方便(至少对我来说 )。本文旨在分享一点在使用nnUnet训练自定义网络过程中的一点经验,可能存在纰漏(???),欢迎在讨论区交流!

二、准备工作

2.1 硬件需求

nnUnet的建议环境是Linux,若使用Windows,需修改路径相关代码(斜杠和反斜杠的替换),很麻烦(不推荐)。博主是在Ubuntu环境中使用Pycharm进行nnUnet的学习。

2.2 调试环境

nnUnet官方推荐的使用方法是在命令行,但这不方便初学者学习。因为只用过Pycharm调试代码(菜!),所以为了满足自己的需求(?),以便于用Pycharm的傻瓜式调试按钮,修改了部分代码: nnunet/paths.pynnunet/run/run_training.py

2.2.1 路径

位于***nnunet/paths.py***文件中,将三个变量路径修改为自己的路径。custom_是博主自己定义的文件,大家可以随意实现。

from custom_ import custom_config
base = custom_config['base']
preprocessing_output_dir = custom_config['preprocessing_output_dir']
network_training_output_dir_base = custom_config['network_training_output_dir_base']

2.2.2 parser

位于***nnunet/run/run_training.py***文件中,这里nnUnet训练代码的入口(!!!)。由于不是命令行调用方式,需要将parser进行修改,添加 “-” 并设置 default 值。

    parser = argparse.ArgumentParser()
    parser.add_argument("-network", default='2d')
    parser.add_argument("-network_trainer", default='nnUNetTrainerV2')
    parser.add_argument("-task", default='666', help="can be task name or task id")
    parser.add_argument("-fold", default='0', help='0, 1, ..., 5 or \'all\'')

三、训练

3.1 构建网络

nnUnet要求网络继承 SegmentationNetwork 类,这里提供一种可实现的方法,用的时候将 self.model 修改为 自定义网络 即可。

from nnunet.network_architecture.neural_network import SegmentationNetwork


class custom_net(SegmentationNetwork):

    def __init__(self, num_classes):
        super(custom_net, self).__init__()
        self.params = {
    
    'content': None}
        self.conv_op = nn.Conv2d
        self.do_ds = True
        self.num_classes = num_classes
        
		######## self.model 设置自定义网络 by Sleeep ########
        self.model = None
        ######## self.model 设置自定义网络 by Sleeep ########
        
        self.name = self.model.name

    def forward(self, x):

        if self.do_ds:
            return [self.model(x), ]
        else:
            return self.model(x)


def create_model():

    return custom_net(num_classes=2)

3.2 修改配置

构建好网络后,还需要修改一些超参数才能完成训练,修改内容位于 /nnunet/training/network_training/nnUNetTrainerV2.py 文件中。修改nnUNetTrainerV2类中的两个函数 initializeinitialize_network。为了减少训练代数,还可修改函数***init***

3.2.1 initialize

def initialize(self, training=True, force_load_plans=False):
        """
        - replaced get_default_augmentation with get_moreDA_augmentation
        - enforce to only run this code once
        - loss function wrapper for deep supervision
        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()
            # load plan informantion !!!   modify batch_size or patch_size after this
            self.process_plans(self.plans)
            ############## modify para by Sleeep ##############
            self.patch_size = np.array(custom_config['patch_size']).astype(int)
            self.batch_size = custom_config['batch_size']
            self.net_num_pool_op_kernel_sizes = [[2, 2]]
            ############## modify para by Sleeep ##############

self.process_plans(self.plans): 官方函数,会载入预处理阶段所生成的各种参数
self.patch_size: 官方预处理后的 图像尺寸不一定满足自定义网络的需要,可在这里修改。举个例子,在预处理阶段,nnUnet自动确定的patch_size为[53, 64](对于2d网络),然而我的网络需要满足输入尺寸均为 32 的整数倍,自动生成的patch_size并不能满足,所以这里可修改为[64, 64]。self.patch_size会在后续构建 数据增强方法 的函数中使用。
self.batch_size: 根据自己的硬件配置修改
self.net_num_pool_op_kernel_sizes: 这个参数非常重要!其作用是 确定 深监督的层数 和 不同层数的尺寸大小。这里默认自定义网络不使用深监督,所以设置 为 只有一个列表元素即可,里面的值随意(可能吧?对于不是使用深监督的情况

3.2.2 initialize_network

 def initialize_network(self):
        """
        - momentum 0.99
        - SGD instead of Adam
        - self.lr_scheduler = None because we do poly_lr
        - deep supervision = True
        - i am sure I forgot something here
        Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
        :return:
        """
        # if self.threeD:
        #     conv_op = nn.Conv3d
        #     dropout_op = nn.Dropout3d
        #     norm_op = nn.InstanceNorm3d
        #
        # else:
        #     conv_op = nn.Conv2d
        #     dropout_op = nn.Dropout2d
        #     norm_op = nn.InstanceNorm2d
        #
        # norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        # dropout_op_kwargs = {'p': 0, 'inplace': True}
        # net_nonlin = nn.LeakyReLU
        # net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        # self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
        #                             len(self.net_num_pool_op_kernel_sizes),
        #                             self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
        #                             dropout_op_kwargs,
        #                             net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
        #                             self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
        ############## add custom model by Sleeep ##############
        self.network = create_model()
        ############## add custom model by Sleeep ##############
        if torch.cuda.is_available():
            self.network.cuda()
        self.network.inference_apply_nonlin = softmax_helper

将原先nnunet网络注释,加入自己的网络。

3.2.3 init

    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
                 unpack_data=True, deterministic=True, fp16=False):
        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
                         deterministic, fp16)
        ##### by Sleeep ####
        self.max_num_epochs = custom_config['epoch']
        self.initial_lr = custom_config['lr']
        ##### by Sleeep ####
        
        self.deep_supervision_scales = None
        self.ds_loss_weights = None

        self.pin_memory = True

默认的epoch是1000,有点久,改小点。默认lr是0.01

四、注意事项

  1. 点赞 收藏 评论 再走
  2. 先过一遍nnunet论文和官方github中的说明。官方提供的一个example ,其中使用的数据集相对来说较小,可先在该数据集上调试成功后,再使用自定义网络

参考资料

  1. nnUnet

猜你喜欢

转载自blog.csdn.net/qq_42811827/article/details/127632891