MMSegmentation配置模型参数文件configParameter.py

from mmengine import Config


cfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py')
dataset_cfg = Config.fromfile('./configs/_base_/datasets/WHUDataset_pipeline.py')
cfg.merge_from_dict(dataset_cfg)


# 类别个数
NUM_CLASS = 2
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.data_preprocessor.test_cfg = dict(size_divisor=128)

# 单卡训练时,需要把 SyncBN 改成 BN
cfg.norm_cfg = dict(type='SyncBN', requires_grad=True) 
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = NUM_CLASS
cfg.model.auxiliary_head.num_classes = NUM_CLASS

# 训练 Batch Size
cfg.train_dataloader.batch_size = 4

# 结果保存目录
cfg.work_dir = './work_dirs/WHUDataset_PSPNet/'

# 模型保存与日志记录

cfg.train_cfg.max_iters = 60000 # 训练迭代次数

cfg.train_cfg.val_interval = 600 # 评估模型间隔

cfg.default_hooks.logger.interval = 200 # 日志记录间隔

cfg.default_hooks.checkpoint.interval = 2000 # 模型权重保存间隔
cfg.default_hooks.checkpoint.max_keep_ckpts = 1 # 最多保留几个模型权重
cfg.default_hooks.checkpoint.save_best = 'mIoU' # 保留指标最高的模型权重

# 随机数种子
cfg['randomness'] = dict(seed=0)

#保存最终的配置文件
cfg.dump('Zihao-Configs/WHUDataset_PSPNet_20230926.py')

猜你喜欢

转载自blog.csdn.net/ArcGis_Niu/article/details/133317054