【Practice】Modifying your own config file in mmdetection

1. Environment setup

| Step | Command | Version used |
|---|---|---|
| Check the CUDA version | nvcc -V | 11.4 |
| Install PyTorch built for your CUDA version (official site: https://pytorch.org/get-started/locally/) | pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 | 1.10.2 |
| Install mmcv-full matching your PyTorch version | pip install mmcv-full==1.4.6 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html | 1.4.6 |
| Install mmdetection (the mmcv and mmdetection versions must be compatible: https://github.com/open-mmlab/mmdetection/blob/master/docs/zh_cn/get_started.md) | git clone https://github.com/open-mmlab/mmdetection.git && cd mmdetection && pip install -r requirements/build.txt && pip install -e . | 2.23.0 |
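The mmcv-full row follows a URL pattern. The wheel index is keyed by two things. The first is the CUDA version PyTorch was built with (torch.version.cuda, which is 11.3 for the cu113 wheels above, and that is why it can differ from the 11.4 reported by nvcc). The second is the PyTorch version reduced to x.y.0, because mmcv-full wheels are, to my knowledge, only published per minor PyTorch version. The sketch below builds that URL. The helper name mmcv_index_url is hypothetical, and it assumes a CUDA build of PyTorch (torch.version.cuda is None on CPU-only builds).

```python
def mmcv_index_url(torch_version: str, cuda_version: str) -> str:
    """Build the mmcv-full wheel index URL for a PyTorch/CUDA combination.

    Hypothetical helper. Assumes the URL layout used in the table above:
    https://download.openmmlab.com/mmcv/dist/cu{XYZ}/torch{x}.{y}.0/index.html

    torch_version: e.g. torch.__version__, such as "1.10.2" or "1.10.2+cu113"
    cuda_version:  e.g. torch.version.cuda, such as "11.3"
    """
    # Drop a local suffix such as "+cu113", then keep only major.minor,
    # since the index directories are named torch{major}.{minor}.0.
    major, minor = torch_version.split("+")[0].split(".")[:2]
    # "11.3" -> "cu113"
    cuda_tag = "cu" + cuda_version.replace(".", "")
    return (
        "https://download.openmmlab.com/mmcv/dist/"
        f"{cuda_tag}/torch{major}.{minor}.0/index.html"
    )
```

For example, pass the values printed by `python -c "import torch; print(torch.__version__, torch.version.cuda)"`. With 1.10.2 and 11.3 the helper reproduces the URL in the mmcv-full row.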

2. Test whether the environment is installed correctly

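# Run from the mmdetection root. The checkpoint is assumed to be downloaded
# there (it is also included in the files shared in section 7).
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

config_file = './configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
device = 'cuda:0'
# Build the model from the config and load the trained weights
model = init_detector(config_file, checkpoint_file, device=device)
# Run inference on the demo image and draw boxes scoring above 0.3
img = './demo/demo.jpg'
result = inference_detector(model, img)
show_result_pyplot(model, img, result, score_thr=0.3)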
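For bbox-only detectors in mmdetection 2.x, such as this Faster R-CNN, inference_detector returns a list with one array per class. Each row of an array is [x1, y1, x2, y2, score]. Inside Docker you may have no display for show_result_pyplot. In that case, a quick text check of the result can help. Below is a minimal sketch using a hypothetical helper, count_detections. It keeps boxes scoring strictly above score_thr and works on numpy arrays as well as plain lists.

```python
def count_detections(result, class_names, score_thr=0.3):
    """Count boxes per class whose score is strictly above score_thr.

    Hypothetical helper. Assumes the mmdetection 2.x bbox result format:
    a list with one entry per class, each a sequence of rows
    [x1, y1, x2, y2, score]. Classes with no kept boxes are omitted.
    """
    counts = {}
    for name, dets in zip(class_names, result):
        kept = [row for row in dets if row[4] > score_thr]
        if kept:
            counts[name] = len(kept)
    return counts
```

Usage after the test script above: `print(count_detections(result, model.CLASSES))`. init_detector reads model.CLASSES from the checkpoint metadata when it is available.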

3. Print out the config file you need

  • View the full config:
python ./tools/misc/print_config.py ./configs/yolox/yolox_l_8x8_300e_coco.py > yolox_l_8x8_300e_coco.txt

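Then save it as a .py file of your own (the training commands in section 5 use ./configs/yolox/my_yolox_l.py) and edit it: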

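  1. Delete the first line "Config:" (print_config.py prints it before the config, and it is not valid Python)
  2. load_from: path to the pretrained model
  3. resume_from: path to the checkpoint to resume from when continuing the previous training run
  4. num_classes: number of classes
  5. ann_file: path to the annotation JSON file
  6. img_prefix: path to the folder of .jpg images
  7. persistent_workers=False: do not keep data-loading worker processes alive (required when workers_per_gpu=0, because PyTorch's DataLoader only allows persistent workers when num_workers > 0)
  8. max_epochs=300: number of training epochs
  9. samples_per_gpu: batch size per GPU
  10. workers_per_gpu: number of data-loading worker processes per GPU; can be set to 0
  11. work_dir: where model weights and logs are saved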
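  • Modify /mmdet/datasets/coco.py:
    CLASSES = ('person',)  # the trailing comma keeps this a one-element tuple
    PALETTE = (220, 20, 60)

  • Modify /mmdet/core/evaluation/class_names.py (the return value of coco_classes()):
    return ['person']

  • Modify /tools/train.py so that training runs with validation:
    validate = True
    (train.py already passes validate=True unless --no-validate is given)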
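The config edits listed above can be pictured as the following excerpt of my_yolox_l.py. This is a sketch, not the author's actual file. Only the edited entries are shown, and every other key printed by print_config.py stays as it is. The test entry under data is edited the same way as val. All paths are hypothetical placeholders. The sketch also reflects one detail of print_config.py output: variables are already expanded there, so changing the top-level max_epochs alone does not change runner.max_epochs. Other values derived from it, such as evaluation's dynamic_intervals, are literals as well.

```python
# Excerpt of an edited my_yolox_l.py (output of print_config.py, first line removed).
# Only edited entries are shown; all paths below are hypothetical placeholders.
model = dict(
    type='YOLOX',
    bbox_head=dict(type='YOLOXHead', num_classes=1))  # one class: 'person'

data = dict(
    samples_per_gpu=8,         # batch size per GPU
    workers_per_gpu=0,         # 0: load data in the main process
    persistent_workers=False,  # must be False when workers_per_gpu=0
    train=dict(
        type='MultiImageMixDataset',
        dataset=dict(
            type='CocoDataset',
            ann_file='data/my_dataset/annotations/train.json',
            img_prefix='data/my_dataset/images/train/')),
    val=dict(
        type='CocoDataset',
        ann_file='data/my_dataset/annotations/val.json',
        img_prefix='data/my_dataset/images/val/'))

max_epochs = 300
# The printed config holds a literal here, so edit it together with max_epochs.
runner = dict(type='EpochBasedRunner', max_epochs=300)

load_from = 'checkpoints/yolox_l_8x8_300e_coco.pth'  # pretrained weights
resume_from = None  # or the checkpoint of an interrupted run
work_dir = './tutorial_exps/yolox_l_8x8_300e_coco'  # weights and logs go here
```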

4. Enable multi-process data loading

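To set a large num_workers (workers_per_gpu in the config), the Docker container must be created with a large shared memory by adding --shm-size="15g" to the docker run command. With that in place, you can use workers_per_gpu=10 and persistent_workers=True. (If this is unclear, other blog posts explain it in more detail.)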
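Docker's --shm-size sets the size of /dev/shm inside the container. The default is 64 MB, which is usually too small for many DataLoader workers. One way to confirm the larger value took effect is to check that mount from Python. The helper name shm_size_gib is hypothetical, and the default path assumes a Linux container.

```python
import shutil


def shm_size_gib(path="/dev/shm"):
    """Return the total size of the filesystem mounted at `path`, in GiB."""
    total, _used, _free = shutil.disk_usage(path)
    return total / 1024 ** 3
```

Inside the container, `print(f"{shm_size_gib():.1f} GiB")` should report roughly 15 when the container was started with --shm-size="15g".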

5. Start training

Single-GPU training (here on GPU 2):

python ./tools/train.py ./configs/yolox/my_yolox_l.py --gpu-id 2

Multi-GPU training (I did not get this to work):

bash ./tools/dist_train.sh ./configs/yolox/my_yolox_l.py 2
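If a training run stops partway, resume_from (or train.py's --resume-from argument) should point at the last saved checkpoint in work_dir. This assumes the runner's default epoch_{N}.pth file naming. A plain string sort picks epoch_9.pth over epoch_10.pth, so the epoch number has to be compared numerically. The sketch below does that with a hypothetical helper, latest_epoch_checkpoint.

```python
import re

# Matches checkpoint names such as "epoch_10.pth" (assumed default naming).
_EPOCH_CKPT = re.compile(r"^epoch_(\d+)\.pth$")


def latest_epoch_checkpoint(filenames):
    """Return the epoch_N.pth name with the largest N, or None if there is none."""
    best_epoch, best_name = -1, None
    for name in filenames:
        match = _EPOCH_CKPT.match(name)
        # Compare N as an integer, not as a string
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_name = int(match.group(1)), name
    return best_name
```

Typical use: `os.path.join(work_dir, latest_epoch_checkpoint(os.listdir(work_dir)))` gives the path to put in resume_from. Check for None first.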

6. Plot curves

| Metric | Command |
|---|---|
| mAP | python ./tools/analysis_tools/analyze_logs.py plot_curve ./tutorial_exps/yolox_l_8x8_300e_coco/backup_2/20220511_030327.log.json --keys bbox_mAP --legend mAP_bbox |
| loss | python ./tools/analysis_tools/analyze_logs.py plot_curve ./tutorial_exps/yolof_r50_c5_8x8_1x_coco/backup_2/20220516_023437.log.json --keys loss_cls loss_bbox --legend loss_cls loss_bbox --start-epoch 2 --out losses.jpg |

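I slightly modified plot_curve in analyze_logs.py, mainly so that --start-epoch takes effect. In the code below, the commented-out lines are the original code that was replaced.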

# plt, np and sns are already imported at the top of analyze_logs.py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_curve(log_dicts, args):
    if args.backend is not None:
        plt.switch_backend(args.backend)
    sns.set_style(args.style)
    # if legend is None, use {filename}_{key} as legend
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            for metric in args.keys:
                legend.append(f'{json_log}_{metric}')
    assert len(legend) == (len(args.json_logs) * len(args.keys))
    metrics = args.keys
    start_epoch = int(args.start_epoch)

    num_metrics = len(metrics)
    for i, log_dict in enumerate(log_dicts):
        epochs = list(log_dict.keys())
        for j, metric in enumerate(metrics):
            print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
            # Look for the metric in any epoch >= start_epoch: the start epoch
            # itself may not be an evaluation epoch.
            if not any(metric in log_dict[epoch]
                       for epoch in epochs if epoch >= start_epoch):
                if 'bbox_mAP' in metric:
                    raise KeyError(
                        f'{args.json_logs[i]} does not contain metric '
                        f'{metric}. Please check if "--no-validate" is '
                        'specified when you trained the model.')
                raise KeyError(
                    f'{args.json_logs[i]} does not contain metric {metric}. '
                    'Please reduce the log interval in the config so that '
                    'interval is less than iterations of one epoch.')

            if 'bbox_mAP' in metric:
                # xs = np.arange(
                #     int(args.start_epoch),
                #     max(epochs), int(args.eval_interval))
                # for epoch in epochs:
                #     ys += log_dict[epoch][metric]
                # Keep only evaluated epochs >= start_epoch (this also covers
                # the last epoch and uneven evaluation intervals).
                xs = []
                ys = []
                for epoch in epochs:
                    if epoch >= start_epoch and metric in log_dict[epoch]:
                        ys += log_dict[epoch][metric]
                        xs.append(epoch)
                ax = plt.gca()
                ax.set_xticks(xs)
                plt.xlabel('epoch')
                plt.plot(xs, ys, label=legend[i * num_metrics + j], marker='o')
            else:
                xs = []
                ys = []
                # original: log_dict[epochs[0]]['iter'][-2]
                num_iters_per_epoch = log_dict[start_epoch]['iter'][-2]
                # for epoch in epochs:
                #     iters = log_dict[epoch]['iter']
                #     if log_dict[epoch]['mode'][-1] == 'val':
                #         iters = iters[:-1]
                #     xs.append(
                #         np.array(iters) + (epoch - 1) * num_iters_per_epoch)
                #     ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
                # Same loop, but skip epochs before start_epoch (every epoch
                # is plotted, including the last one).
                for epoch in epochs:
                    if epoch < start_epoch:
                        continue
                    iters = log_dict[epoch]['iter']
                    if log_dict[epoch]['mode'][-1] == 'val':
                        iters = iters[:-1]
                    xs.append(
                        np.array(iters) + (epoch - 1) * num_iters_per_epoch)
                    ys.append(np.array(log_dict[epoch][metric][:len(iters)]))
                xs = np.concatenate(xs)
                ys = np.concatenate(ys)
                plt.xlabel('iter')
                plt.plot(
                    xs, ys, label=legend[i * num_metrics + j], linewidth=0.5)
            plt.legend()
        if args.title is not None:
            plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        plt.cla()
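--keys must name a metric that actually appears in the .log.json file. For example, COCO-style evaluation logs bbox_mAP rather than a plain mAP. Each line of that file is a JSON record. Like analyze_logs.py, the sketch below skips records without an "epoch" field, such as the leading environment/config record. The helper name metric_keys is hypothetical. It lists the available keys per mode so you can pick valid --keys values.

```python
import json
from collections import defaultdict


def metric_keys(lines):
    """Collect the metric names logged per mode ('train'/'val') in a .log.json file.

    `lines` is any iterable of JSON lines, e.g. an open file object.
    Records without an 'epoch' field are skipped.
    """
    keys = defaultdict(set)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if 'epoch' not in record:
            continue
        mode = record.get('mode', 'unknown')
        # Everything except the bookkeeping fields can be plotted
        keys[mode].update(k for k in record if k not in ('mode', 'epoch', 'iter'))
    return {mode: sorted(names) for mode, names in keys.items()}
```

Usage: `with open(path) as f: print(metric_keys(f))`, where path is a log such as ./tutorial_exps/yolox_l_8x8_300e_coco/backup_2/20220511_030327.log.json.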

7. Files used in this post

These include my_yolox_l.py and faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth:
Link: https://pan.baidu.com/s/1rEqjj8kzFaRS_3ynXFX2MQ
Extraction code: 4e34


Reposted from blog.csdn.net/weixin_42326479/article/details/124750932