SOLO in Practice: Training an Instance Segmentation Model on Your Own Dataset

Paper: https://arxiv.org/abs/1912.04488

Code: https://github.com/WXinlong/SOLO

1. Setting Up the Training Environment

    I set up the training environment with conda; the following commands are all it takes:

conda create -n solo python=3.7 -y
conda activate solo

conda install -c pytorch pytorch torchvision -y
conda install cython -y
git clone https://github.com/WXinlong/SOLO.git
cd SOLO
pip install -r requirements/build.txt
pip install "git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI"
pip install -v -e .

       Of course, if you would rather set up the environment another way, follow the instructions in the official repository.

2. Preparing the Dataset

     I recommend converting your dataset to the COCO format; that way very little code needs to change.

     For details of the COCO annotation format, see my earlier post: https://blog.csdn.net/Guo_Python/article/details/105839280
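As a quick reference, here is a minimal COCO-style instance segmentation annotation file for the three classes used later in this post. The image name, sizes, bbox, and polygon are made-up placeholders; only the field names follow the COCO format:

```python
import json

# Minimal COCO-style annotation file (placeholder values)
coco = {
    "images": [
        {"id": 1, "file_name": "000001.jpg", "width": 640, "height": 480}
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,
            "category_id": 1,
            "bbox": [100, 120, 200, 150],  # [x, y, w, h]
            "area": 200 * 150,
            # polygon: [x1, y1, x2, y2, ...]
            "segmentation": [[100, 120, 300, 120, 300, 270, 100, 270]],
            "iscrowd": 0,
        }
    ],
    "categories": [
        {"id": 1, "name": "people"},
        {"id": 2, "name": "dog"},
        {"id": 3, "name": "cat"},
    ],
}

with open("train.json", "w") as f:
    json.dump(coco, f)
```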

3. Modifying the Code

      1. Register your dataset

          Create a file Your_dataset.py under mmdet/datasets/ with the following content (it inherits from CocoDataset):

from .coco import CocoDataset
from .registry import DATASETS

#add new dataset
@DATASETS.register_module
class Your_Dataset(CocoDataset):
    CLASSES = ['people', 'dog', 'cat']

           Then add the new dataset to mmdet/datasets/__init__.py; the modified __init__.py looks like this:

from .builder import build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
from .registry import DATASETS
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .Your_dataset import Your_Dataset

__all__ = [
    'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset',
    'CityscapesDataset', 'GroupSampler', 'DistributedGroupSampler',
    'build_dataloader', 'ConcatDataset', 'RepeatDataset', 'WIDERFaceDataset',
    'DATASETS', 'build_dataset', 'Your_Dataset'
]
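One pitfall worth noting: CocoDataset maps category entries to label indices following their order in the annotation file, so CLASSES should list the names in that same order (usually ascending category id). A hypothetical helper to generate the tuple from the categories list (classes_from_coco is my own name, not part of mmdet):

```python
def classes_from_coco(categories):
    """Return class names sorted by ascending COCO category id,
    the order CLASSES should normally follow."""
    return tuple(c["name"] for c in sorted(categories, key=lambda c: c["id"]))

# Hypothetical categories list, deliberately out of order:
cats = [
    {"id": 2, "name": "dog"},
    {"id": 1, "name": "people"},
    {"id": 3, "name": "cat"},
]
print(classes_from_coco(cats))  # ('people', 'dog', 'cat')
```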

       2. Modify the config file

           Edit configs/solo/solo_r50_fpn_8gpu_3x.py. The main changes are the data paths, the training image scales, and the number of classes; the modified file:

# model settings
model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3), # C2, C3, C4, C5
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=dict(
        type='SOLOHead',
        num_classes=4,   # number of classes + 1 for background (3 classes here)
        in_channels=256,
        stacked_convs=4,
        seg_feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        sigma=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cate_down_pos=0,
        with_deform=False,
        loss_ins=dict(
            type='DiceLoss',
            use_sigmoid=True,
            loss_weight=3.0),
        loss_cate=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
    ))
# training and testing settings
train_cfg = dict()
test_cfg = dict(
    nms_pre=500,
    score_thr=0.1,
    mask_thr=0.5,
    update_thr=0.05,
    kernel='gaussian',  # gaussian/linear
    sigma=2.0,
    max_per_img=100)
# dataset settings
dataset_type = 'Your_Dataset'         # the dataset type registered above
data_root = '/home/gp/dukto/Data/'    # root directory of your dataset
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize',
         # adjust the image scales to your dataset
         img_scale=[(640, 480), (640, 420), (640, 400),
                    (640, 360), (640, 320), (640, 300)],
         multiscale_mode='value',
         keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(640, 480),    # adjust the image scale to your dataset
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    imgs_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'train.json',
        img_prefix=data_root + 'images/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/test.json',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/test.json',
        img_prefix=data_root + 'images/',
        pipeline=test_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 36
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/solo_release_r50_fpn_3x'     # where checkpoints and training logs are saved
load_from = None
resume_from = None
workflow = [('train', 1)]
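With keep_ratio=True, the Resize step rescales each image by the largest factor that keeps the long edge within max(img_scale) and the short edge within min(img_scale), mimicking mmcv's imrescale. A small sketch of that rule (rescale_size is an illustrative helper, not the mmcv function):

```python
def rescale_size(w, h, scale):
    """Scale (w, h) by the largest factor keeping the long edge
    within max(scale) and the short edge within min(scale)."""
    max_long, max_short = max(scale), min(scale)
    factor = min(max_long / max(w, h), max_short / min(w, h))
    return round(w * factor), round(h * factor)

# A hypothetical 1280x720 input with img_scale=(640, 480):
print(rescale_size(1280, 720, (640, 480)))  # (640, 360)
```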

4. Training the Model

      Start training with the following command (for multi-GPU training, the repository also provides tools/dist_train.sh):

python tools/train.py configs/solo/solo_r50_fpn_8gpu_3x.py

      If all goes well, your model should now be training; the log looks like this:

2020-06-06 08:59:44,359 - mmdet - INFO - Start running, host: gp@gp-System-Product-Name, work_dir: /home/gp/work/project/SOLO/work_dirs/solo_release_r50_fpn_3x
2020-06-06 08:59:44,360 - mmdet - INFO - workflow: [('train', 1)], max: 36 epochs
2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077]	lr: 0.00399, eta: 6:58:04, time: 0.648, data_time: 0.012, memory: 4378, loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386
2020-06-06 09:00:49,382 - mmdet - INFO - Epoch [1][100/1077]	lr: 0.00465, eta: 6:59:03, time: 0.653, data_time: 0.009, memory: 4404, loss_ins: 2.9341, loss_cate: 0.7440, loss: 3.6781
2020-06-06 09:01:23,307 - mmdet - INFO - Epoch [1][150/1077]	lr: 0.00532, eta: 7:04:35, time: 0.678, data_time: 0.009, memory: 4404, loss_ins: 2.9176, loss_cate: 0.6833, loss: 3.6009
2020-06-06 09:01:57,125 - mmdet - INFO - Epoch [1][200/1077]	lr: 0.00599, eta: 7:06:44, time: 0.676, data_time: 0.008, memory: 4404, loss_ins: 2.8258, loss_cate: 0.6749, loss: 3.5007
2020-06-06 09:02:31,695 - mmdet - INFO - Epoch [1][250/1077]	lr: 0.00665, eta: 7:09:43, time: 0.691, data_time: 0.009, memory: 4404, loss_ins: 2.6120, loss_cate: 0.6709, loss: 3.2829
2020-06-06 09:03:05,609 - mmdet - INFO - Epoch [1][300/1077]	lr: 0.00732, eta: 7:10:07, time: 0.678, data_time: 0.008, memory: 4404, loss_ins: 2.3819, loss_cate: 0.6161, loss: 2.9980
2020-06-06 09:03:41,319 - mmdet - INFO - Epoch [1][350/1077]	lr: 0.00799, eta: 7:13:31, time: 0.714, data_time: 0.009, memory: 4404, loss_ins: 2.2753, loss_cate: 0.5814, loss: 2.8567
2020-06-06 09:04:16,615 - mmdet - INFO - Epoch [1][400/1077]	lr: 0.00865, eta: 7:15:16, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 2.1902, loss_cate: 0.5767, loss: 2.7669
2020-06-06 09:04:50,585 - mmdet - INFO - Epoch [1][450/1077]	lr: 0.00932, eta: 7:14:37, time: 0.679, data_time: 0.009, memory: 4404, loss_ins: 2.0392, loss_cate: 0.5654, loss: 2.6046
2020-06-06 09:05:23,447 - mmdet - INFO - Epoch [1][500/1077]	lr: 0.00999, eta: 7:12:34, time: 0.657, data_time: 0.008, memory: 4404, loss_ins: 1.9716, loss_cate: 0.5497, loss: 2.5212
2020-06-06 09:05:58,198 - mmdet - INFO - Epoch [1][550/1077]	lr: 0.01000, eta: 7:12:59, time: 0.695, data_time: 0.008, memory: 4404, loss_ins: 1.8707, loss_cate: 0.5178, loss: 2.3885
2020-06-06 09:06:31,807 - mmdet - INFO - Epoch [1][600/1077]	lr: 0.01000, eta: 7:12:01, time: 0.672, data_time: 0.008, memory: 4404, loss_ins: 1.9311, loss_cate: 0.5885, loss: 2.5196
2020-06-06 09:07:05,118 - mmdet - INFO - Epoch [1][650/1077]	lr: 0.01000, eta: 7:10:49, time: 0.666, data_time: 0.009, memory: 4404, loss_ins: 1.8655, loss_cate: 0.5750, loss: 2.4404
2020-06-06 09:07:39,825 - mmdet - INFO - Epoch [1][700/1077]	lr: 0.01000, eta: 7:10:59, time: 0.694, data_time: 0.009, memory: 4404, loss_ins: 1.7929, loss_cate: 0.5022, loss: 2.2951
2020-06-06 09:08:15,143 - mmdet - INFO - Epoch [1][750/1077]	lr: 0.01000, eta: 7:11:34, time: 0.706, data_time: 0.009, memory: 4404, loss_ins: 1.8265, loss_cate: 0.5218, loss: 2.3482
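The lr column climbing from 0.00399 toward 0.01 over the first 500 iterations is the linear warmup configured in lr_config. A sketch of mmcv's linear-warmup rule (the function name is mine):

```python
def linear_warmup_lr(base_lr, cur_iter, warmup_iters, warmup_ratio):
    # mmcv linear warmup (sketch): lr ramps from
    # base_lr * warmup_ratio up to base_lr over warmup_iters.
    k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
    return base_lr * (1 - k)

# Values from the config above: lr=0.01, warmup_iters=500, warmup_ratio=1/3
for it in (50, 100, 250, 500):
    print(it, round(linear_warmup_lr(0.01, it, 500, 1.0 / 3), 5))
# iteration 50 gives 0.004, matching the "lr: 0.00399" logged above
```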

       The loss curves over training look like this:

[Figure: training loss curves]
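If you want to reproduce such a curve yourself, the total loss can be parsed out of the text log; a minimal sketch (the sample line is taken from the training log above):

```python
import re

# Match the final "loss: X" field at the end of an mmdet log line
LOSS_RE = re.compile(r"\bloss: ([\d.]+)\s*$")

def parse_losses(lines):
    """Extract total-loss values from mmdet text-log lines."""
    losses = []
    for line in lines:
        m = LOSS_RE.search(line)
        if m:
            losses.append(float(m.group(1)))
    return losses

log = [
    "2020-06-06 09:00:16,755 - mmdet - INFO - Epoch [1][50/1077]\tlr: 0.00399, "
    "eta: 6:58:04, time: 0.648, data_time: 0.012, memory: 4378, "
    "loss_ins: 2.9387, loss_cate: 0.8999, loss: 3.8386",
]
print(parse_losses(log))  # [3.8386]
```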

5. Testing the Model

       Follow the inference code in the GitHub repository; my test results are shown below:
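For reference, the demo in the SOLO repository runs inference roughly like this. The checkpoint path below assumes the work_dir from the config above (latest.pth is written there by the checkpoint hook), and test.jpg is a placeholder image name; this requires a trained model and a GPU:

```python
from mmdet.apis import init_detector, inference_detector, show_result_ins

config_file = 'configs/solo/solo_r50_fpn_8gpu_3x.py'
# latest.pth is saved in work_dir by the checkpoint hook; adjust as needed
checkpoint_file = './work_dirs/solo_release_r50_fpn_3x/latest.pth'

# build the model from the config file and trained weights
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# run inference on a single image and save the visualization
img = 'test.jpg'
result = inference_detector(model, img)
show_result_ins(img, result, model.CLASSES, score_thr=0.25,
                out_file='test_out.jpg')
```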

 

[Figure: example segmentation results]

That's all! If you have any questions, please leave a comment!


Reposted from blog.csdn.net/Guo_Python/article/details/106623798