detectron2自定义数据集

数据处理:
- 数据文件目录
        - 一级目录:VOC2007
          - 二级目录:Annotations #存放标注后的voc格式文件
          -         ImageSets
            - 三级目录:Main #存放分割后的TXT索引文件
          -         JPEGImages #存放图片文件
          -         train_annotations
          -         train_JPEGImages
          -         val_annotations
          -         val_JPEGImages
          ----      1_split.py
        - 一级目录:coco
          - 二级目录:annotations #
          -         train2017
          -         val2017
        ----      2_mv.py
        ----      3_xml_json.py//3_v2v_2.py
    ----      4_dataset_test.py
    - 训练模型:model_train
        -         output_trainsample
        ----      5_trainsample.py
    将图片放入JPEGImages里,标注好的voc文件放入Annotations里。
1_split.py #生成索引文件(trainval.txt、test.txt、train.txt、testval.txt)
2_mv.py #图片标记分割,根据1中生成的索引文件,对图片,xml文件进行分割和复制[先将训练集分离出来]。(train_annotations、train_JPEImages、val_annotations、val_JPEImages)
3_xml_json.py//3_v2v_2.py #将voc格式转json格式用于detectron的训练
4_dataset_test.py #数据集测试
5_trainsample.py #训练模型

训练/测试脚本的命令行参数示例: --opts MODEL.DEVICE cuda MODEL.WEIGHTS ../model_train/output_Warning_Faster_R-CNN_test/model_final.pth

 目标分割检测:

- dataset_17
    - JPEGImages #存放图片和标注后的json文件
- 2_json-voc.py
- detectron2
    - dataset
        - coco_17coco
            - annotations
                - instances_train2017.json
                - instances_val2017.json
            - train2017 #图片
            - val2017
    - 4_dataset_test.py

(1)-Labelme标注图片,生成json文件
(2)-2_json-voc.py   #生成train和val两个json文件用于detectron训练
        修改自己的类别 classname_to_id = {"clothing":0,……}
        图片路径 labelme_path = './dataset_17/JPEGImages';保存路径saved_coco_path = './dateset/coco_17'
(3)-4_dataset_test.py   #数据集测试
        修改自己的类别 DATASET_CATEGORIES = [{"name": "clothing", "id": 0, "isthing": 1, "color": [220, 20, 60]},……]
        数据集路径 DATASET_ROOT = './datasets/coco_17coco'
(4)-5-2-3_trainsample.py   #训练模型
        修改自己的类别 DATASET_CATEGORIES = [{"name": "clothing", "id": 0, "isthing": 1, "color": [220, 20, 60]},……]
        数据集路径 DATASET_ROOT = './datasets/coco_17coco'
        修改类别数 cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4
        修改训练总批次 cfg.SOLVER.MAX_ITER = 560

1_split.py

'''Generate the split index files (TXT) for the VOC dataset.

Writes four index files under ``ImageSets/Main``. Each line is one
annotation basename, without the ``.xml`` extension:

* ``train.txt``    -- every file NOT drawn into the held-out sample
                      (about ``1 - trainval_percent`` of the data);
* ``trainval.txt`` -- the held-out sample (about ``trainval_percent``);
* ``test.txt``     -- the ``train_percent`` share of the held-out sample;
* ``testval.txt``  -- the rest of the held-out sample.

NOTE: despite their names, ``trainval_percent`` is the held-out fraction and
``train_percent`` is the share of that sample written to ``test.txt``.
2_mv.py reads ``train.txt`` as the (large) training split.
'''


import os
import random
# Split ratios -- edit here (see the module docstring for what each controls).
trainval_percent = 0.1
train_percent = 0.9
xmlfilepath = 'Annotations'
# os.path.join instead of a hard-coded backslash so the path also works on POSIX.
txtsavepath = os.path.join('ImageSets', 'Main')
# Only .xml files are annotations; stray files (e.g. .DS_Store) must not end
# up in the index files.
total_xml = [f for f in os.listdir(xmlfilepath) if f.lower().endswith('.xml')]

num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
# Draw the samples exactly as before (same RNG call sequence), then keep them
# as sets so the membership tests in the loop are O(1) instead of O(n).
trainval_sample = random.sample(indices, tv)
train_sample = random.sample(trainval_sample, tr)
trainval = set(trainval_sample)
train = set(train_sample)

os.makedirs(txtsavepath, exist_ok=True)

# The with-statement closes all four files even if writing fails.
with open(os.path.join(txtsavepath, 'trainval.txt'), 'w') as ftrainval, \
        open(os.path.join(txtsavepath, 'test.txt'), 'w') as ftest, \
        open(os.path.join(txtsavepath, 'train.txt'), 'w') as ftrain, \
        open(os.path.join(txtsavepath, 'testval.txt'), 'w') as fval:
    for i in indices:
        name = os.path.splitext(total_xml[i])[0] + '\n'
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftest.write(name)
            else:
                fval.write(name)
        else:
            ftrain.write(name)

2_mv.py

'''图片标记分割
根据1中生成的索引文件,对图片,xml文件进行分割和复制
先将训练集分离出来。'''


import os
import shutil


class CopyXml:
    """Copy the images listed in a split index file into a separate folder.

    ``startcopy`` copies every file of ``self.jpgpath`` whose name without
    extension appears in the list returned by ``loadFileList()`` into
    ``self.newjpgpath``. The commented-out ``xmlpath``/``newxmlpath`` lines
    are the variant for copying the xml annotations instead.
    """

    def __init__(self):
        # Folder with your VOC-format xml annotations
        # self.xmlpath = r'VOC2007/Annotations'
        self.jpgpath = r'VOC2007/JPEGImages'
        # Destination folders for the selected train/test xml and jpg files
        # self.newxmlpath = r'VOC2007/val_annotations'
        self.newjpgpath = r'VOC2007/train_JPEGImages'

    def startcopy(self):
        """Copy each image listed by ``loadFileList()`` from ``jpgpath`` to ``newjpgpath``."""
        # filelist = os.listdir(self.xmlpath)
        filelist = os.listdir(self.jpgpath)
        # A set gives O(1) lookups; scanning the list made the loop O(files * entries).
        selected = set(loadFileList())
        # shutil.copyfile fails if the destination folder does not exist yet.
        os.makedirs(self.newjpgpath, exist_ok=True)
        for f in filelist:
            # xmldir = os.path.join(self.xmlpath, f)
            jpgdir = os.path.join(self.jpgpath, f)
            shotname, _extension = os.path.splitext(f)
            if shotname in selected:
                # shutil.copyfile(xmldir, os.path.join(self.newxmlpath, f))
                shutil.copyfile(jpgdir, os.path.join(self.newjpgpath, f))


# load the list of train/test file list
def loadFileList(path="VOC2007/ImageSets/Main/train.txt"):
    """Read a split index file and return its entries.

    Args:
        path: index file to read. Defaults to the training split
            (``VOC2007/ImageSets/Main/train.txt``). Pass
            ``VOC2007/ImageSets/Main/trainval.txt`` for the held-out split.

    Returns:
        A list with one string per line, with trailing carriage-return and
        newline characters removed.
    """
    # Default text encoding on purpose: 1_split.py also writes with open(..., 'w').
    with open(path, "r") as f:
        # strip('\r\n') copes with both '\n' and '\r\n' line endings.
        return [line.strip('\r\n') for line in f]


if __name__ == '__main__':
    # Copy the images listed in the split index file into the target folder.
    CopyXml().startcopy()

3_v2v_2.py

"""
Created on Tue Jun 12 10:24:36 2018
将voc格式转json格式用于caffe2的detectron的训练
在detectron中voc_2007_train.json和voc_2007_val.json中categories的顺序必须保持一致
因此最好事先确定category的顺序,书写在category_set中
@author: yantianwang
"""
 
import xml.etree.ElementTree as ET
import os
import json
import collections
 
# Skeleton of the COCO-style json written by this script. The key order
# (images, type, annotations, categories) is also the order in the dumped file.
coco = {
    'images': [],
    'type': 'instances',
    'annotations': [],
    'categories': [],
}

# file_name values already added to coco['images']; used to reject duplicates.
image_set = set()
# First image id handed out by addImgItem. The originally noted id ranges were
# train: 2018xxx, val: 2019xxx, test: 2020xxx.
image_id = 2022000001
# Next ids handed out by addCatItem / addAnnoItem; both start at 0.
category_item_id = 0
annotation_id = 0
# Category names in id order: a category's id is its index in this list.
# The train and val json must share this order, so it is fixed here up front.
category_set = ['1', '2', "0"]
 
def addCatItem(name):
    '''
    Append one entry to the ``categories`` section of the json.

    The entry gets the current global ``category_item_id`` as its id, and the
    counter is then advanced. Calling this once per name in order therefore
    assigns ids 0, 1, 2, ...
    '''
    global category_item_id
    entry = collections.OrderedDict([
        ('supercategory', 'none'),
        ('id', category_item_id),
        ('name', name),
    ])
    coco['categories'].append(entry)
    category_item_id += 1
 
def addImgItem(file_name, size):
    '''
    Append one entry to the ``images`` section of the json.

    Args:
        file_name: image file name taken from the xml <filename> tag.
        size: dict holding at least the 'width' and 'height' values.

    Returns:
        The global ``image_id`` AFTER it was incremented, i.e. the id given
        to this image plus one (callers subtract 1).

    Raises:
        Exception: if the file name, width or height is None.
    '''
    global image_id
    if file_name is None:
        raise Exception('Could not find filename tag in xml file.')
    if size['width'] is None:
        raise Exception('Could not find width tag in xml file.')
    if size['height'] is None:
        raise Exception('Could not find height tag in xml file.')
    print(file_name, "*******")
    # An OrderedDict keeps the field order stable in the dumped json.
    assigned_id = image_id
    coco['images'].append(collections.OrderedDict([
        ('file_name', file_name),
        ('width', size['width']),
        ('height', size['height']),
        ('id', assigned_id),
    ]))
    image_set.add(file_name)
    image_id = assigned_id + 1
    return image_id
 
 
def addAnnoItem(object_name, image_id, category_id, bbox):
    '''
    Append one entry to the ``annotations`` section of the json.

    Args:
        object_name: class name from the xml (not stored; kept for the interface).
        image_id: id of the image the box belongs to.
        category_id: id of the box's category.
        bbox: [x, y, w, h]; stored as-is (the same list object).

    The ``segmentation`` is the box outline as a single polygon visiting
    left-top, left-bottom, right-bottom, right-top. ``area`` is w * h.
    '''
    global annotation_id
    left, top, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
    right = left + width
    bottom = top + height
    polygon = [left, top, left, bottom, right, bottom, right, top]
    entry = collections.OrderedDict([
        ('segmentation', [polygon]),
        ('area', width * height),
        ('iscrowd', 0),
        ('image_id', image_id),
        ('bbox', bbox),
        ('category_id', category_id),
        ('id', annotation_id),
        ('ignore', 0),
    ])
    annotation_id += 1
    coco['annotations'].append(entry)
 
 
def parseXmlFiles(xml_path):
    """Convert every Pascal VOC ``.xml`` file in *xml_path* into the global ``coco`` dict.

    Files are read in sorted name order. For each file, one image entry is
    added through addImgItem, and one annotation per <object>/<bndbox> is
    added through addAnnoItem. Afterwards one category entry per name in
    ``category_set`` is added through addCatItem, in list order.

    Args:
        xml_path: folder holding the VOC xml files (other files are skipped).

    Raises:
        Exception: on an unexpected xml structure or a duplicated image file name.
        ValueError: if an object's <name> is not listed in ``category_set``.
    """
    xmllist = os.listdir(xml_path)
    xmllist.sort()
    for f in xmllist:
        if not f.endswith('.xml'):
            continue

        bndbox = dict()
        size = dict()
        current_image_id = None
        current_category_id = None
        file_name = None
        size['width'] = None
        size['height'] = None
        size['depth'] = None

        xml_file = os.path.join(xml_path, f)
        print(xml_file)

        tree = ET.parse(xml_file)
        root = tree.getroot()  # root element

        if root.tag != 'annotation':
            raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

        # elem is <folder>, <filename>, <size>, <object>, ...
        for elem in root:
            current_parent = elem.tag
            current_sub = None
            object_name = None

            if elem.tag == 'folder':
                continue

            if elem.tag == 'filename':
                # Duplicated file names are rejected below through image_set.
                # (The previous check looked file_name up in category_set,
                # which could only misfire on names equal to a category name.)
                file_name = elem.text

            # The image entry is added on the first element that follows both
            # <filename> and <size>: size['width'] is filled in only while the
            # children of <size> are iterated further down.
            elif current_image_id is None and file_name is not None and size['width'] is not None:
                if file_name not in image_set:
                    # addImgItem returns the image's id + 1 (hence the "- 1" below).
                    current_image_id = addImgItem(file_name, size)
                    print('add image with {} and {}'.format(file_name, size))
                else:
                    raise Exception('duplicated image: {}'.format(file_name))

            # subelem is <width>, <height>, <depth>, <name>, <bndbox>, ...
            for subelem in elem:
                bndbox['xmin'] = None
                bndbox['xmax'] = None
                bndbox['ymin'] = None
                bndbox['ymax'] = None

                current_sub = subelem.tag
                if current_parent == 'object' and subelem.tag == 'name':
                    object_name = subelem.text
                    # The category id is the 0-based index in category_set,
                    # matching the ids addCatItem assigns (no +1).
                    if object_name not in category_set:
                        raise ValueError('unknown category {!r} in {}; add it to category_set'
                                         .format(object_name, xml_file))
                    current_category_id = category_set.index(object_name)
                elif current_parent == 'size':
                    if size[subelem.tag] is not None:
                        raise Exception('xml structure broken at size tag.')
                    size[subelem.tag] = int(subelem.text)

                # option is <xmin>, <ymin>, <xmax>, <ymax>, when subelem is <bndbox>
                for option in subelem:
                    if current_sub == 'bndbox':
                        if bndbox[option.tag] is not None:
                            raise Exception('xml structure corrupted at bndbox tag.')
                        # int(float(...)) also accepts coordinates written as
                        # decimals (e.g. "123.0", truncated toward zero);
                        # plain int() raised ValueError on them.
                        bndbox[option.tag] = int(float(option.text))

                # Only true right after a <bndbox> child has been parsed.
                if bndbox['xmin'] is not None:
                    if object_name is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_image_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_category_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    bbox = []
                    # x
                    bbox.append(bndbox['xmin'])
                    # y
                    bbox.append(bndbox['ymin'])
                    # w
                    bbox.append(bndbox['xmax'] - bndbox['xmin'])
                    # h
                    bbox.append(bndbox['ymax'] - bndbox['ymin'])
                    print(
                    'add annotation with {},{},{},{}'.format(object_name, current_image_id-1, current_category_id, bbox))
                    addAnnoItem(object_name, current_image_id-1, current_category_id, bbox)

        # Files in which no element follows both <filename> and <size> (no
        # <object> and no <segmented> after <size>) never reach the branch
        # above; add their image here so they are not silently dropped.
        if current_image_id is None and file_name is not None and size['width'] is not None:
            if file_name in image_set:
                raise Exception('duplicated image: {}'.format(file_name))
            addImgItem(file_name, size)
            print('add image with {} and {}'.format(file_name, size))

    # categories section: one entry per name, ids follow category_set order
    for categoryname in category_set:
        addCatItem(categoryname)
 
 
if __name__ == '__main__':
    # One split per run: switch both paths together to convert the train split.
    # xml_path = 'VOC2007/train_annotations'
    xml_path = 'VOC2007/val_annotations'
    # json_file = 'coco/annotations/_instances_train2022.json'
    json_file = 'coco/annotations/_instances_val2022.json'
    # NOTE(review): 4_dataset_test.py / 5_trainsample.py read
    # instances_train2017.json / instances_val2017.json, so rename the output
    # (or adjust those scripts) before using it there.
    parseXmlFiles(xml_path)
    # Create the output folder if needed, and close the file deterministically
    # so the json is fully flushed to disk.
    out_dir = os.path.dirname(json_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(json_file, 'w') as fp:
        json.dump(coco, fp)

4_dataset_test.py

'''数据集测试'''


import os
import cv2
import logging
from collections import OrderedDict

import detectron2.utils.comm as comm
from detectron2.utils.visualizer import Visualizer
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_coco_json
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import COCOEvaluator, verify_results
from detectron2.modeling import GeneralizedRCNNWithTTA

# Dataset paths
DATASET_ROOT = './datasets/coco_17coco'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')

TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')

TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
# VAL_JSON = os.path.join(ANN_ROOT, 'val.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')

# Category metadata of the dataset.
# NOTE(review): detectron2's load_coco_json also sets thing_classes from the
# json's categories, so keep names/ids here in sync with the json files.
DATASET_CATEGORIES = [
    # {"name": "background", "id": 0, "isthing": 1, "color": [220, 20, 60]},
    {"name": "clothing", "id": 0, "isthing": 1, "color": [220, 20, 60]},
    {"name": "face_shield", "id": 1, "isthing": 1, "color": [110, 20, 60]},
    {"name": "boot", "id": 2, "isthing": 1, "color": [220, 20, 30]},
    # {"name": "clothing_f", "id": 3, "isthing": 1, "color": [220, 10, 30]},
]

# Class names used by plain_register_dataset. They are derived from
# DATASET_CATEGORIES so both registration paths agree; the previous
# hard-coded ['line'] did not match the categories above.
CLASS_NAMES = [c["name"] for c in DATASET_CATEGORIES if c["isthing"] == 1]

# Dataset splits: registered name -> (image folder, annotation json)
PREDEFINED_SPLITS_DATASET = {
    "train_2019": (TRAIN_PATH, TRAIN_JSON),
    "val_2019": (VAL_PATH, VAL_JSON),
}


def register_dataset():
    """
    Register every split listed in PREDEFINED_SPLITS_DATASET.

    Each split is registered with its image folder, its annotation json and a
    freshly built metadata dict from get_dataset_instances_meta().
    """
    for split_name, paths in PREDEFINED_SPLITS_DATASET.items():
        image_root, json_file = paths
        register_dataset_instances(
            name=split_name,
            metadate=get_dataset_instances_meta(),
            json_file=json_file,
            image_root=image_root,
        )


def get_dataset_instances_meta():
    """
    Build the detectron2 metadata dict from DATASET_CATEGORIES.

    Only entries with ``isthing == 1`` are used. Their dataset ids are mapped
    to contiguous ids 0..N-1 in list order.

    return: dict with the keys "thing_dataset_id_to_contiguous_id",
            "thing_classes" and "thing_colors"
    """
    things = [cat for cat in DATASET_CATEGORIES if cat["isthing"] == 1]
    return {
        "thing_dataset_id_to_contiguous_id": {cat["id"]: idx for idx, cat in enumerate(things)},
        "thing_classes": [cat["name"] for cat in things],
        "thing_colors": [cat["color"] for cat in things],
    }


def register_dataset_instances(name, metadate, json_file, image_root):
    """
    Register one split with detectron2.

    DatasetCatalog gets a loader calling load_coco_json(json_file, image_root,
    name). MetadataCatalog gets the two paths, evaluator_type="coco" and every
    key of *metadate*.
    """
    def _load():
        return load_coco_json(json_file, image_root, name)

    DatasetCatalog.register(name, _load)
    MetadataCatalog.get(name).set(
        json_file=json_file,
        image_root=image_root,
        evaluator_type="coco",
        **metadate,
    )


# 注册数据集和元数据
def plain_register_dataset():
    """
    Register train_2019 / val_2019 with minimal metadata.

    This is an alternative to register_dataset(): the metadata holds only
    CLASS_NAMES and the json/image paths. Not called anywhere in this file.
    """
    splits = (
        ("train_2019", TRAIN_JSON, TRAIN_PATH),
        ("val_2019", VAL_JSON, VAL_PATH),
    )
    for split, ann_file, img_dir in splits:
        # Default arguments pin this iteration's values (avoids late binding).
        DatasetCatalog.register(
            split, lambda a=ann_file, i=img_dir, s=split: load_coco_json(a, i, s))
        MetadataCatalog.get(split).set(thing_classes=CLASS_NAMES,
                                       json_file=ann_file,
                                       image_root=img_dir)


# 查看数据集标注
def checkout_dataset_annotation(name="train_2019"):
    """Show each image of split *name* with its annotations drawn; any key advances.

    The split's image folder and json are looked up in
    PREDEFINED_SPLITS_DATASET. Unknown names fall back to the training split,
    as before. Previously the training json was loaded for every *name*.
    """
    image_root, json_file = PREDEFINED_SPLITS_DATASET.get(name, (TRAIN_PATH, TRAIN_JSON))
    dataset_dicts = load_coco_json(json_file, image_root, name)
    # A resizable window, created once instead of once per image.
    cv2.namedWindow('findCorners', cv2.WINDOW_NORMAL)
    for d in dataset_dicts:
        print(d['file_name'])
        img = cv2.imread(d["file_name"])
        if img is None:
            # cv2.imread returns None for a missing/unreadable file; slicing it would crash.
            print('could not read image, skipping: {}'.format(d['file_name']))
            continue
        # Channels are flipped for the Visualizer (OpenCV loads BGR) and back for imshow.
        visualizer = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(name), scale=1.5)
        vis = visualizer.draw_dataset_dict(d)
        cv2.imshow('findCorners', vis.get_image()[:, :, ::-1])
        cv2.waitKey(0)
    cv2.destroyAllWindows()



# Script entry (runs at import time too; there is no __main__ guard):
# register both splits, then step through the training split's images.
register_dataset()

checkout_dataset_annotation()

5_trainsample.py

import os
import cv2
import logging
from collections import OrderedDict

import detectron2.utils.comm as comm
from detectron2.utils.visualizer import Visualizer
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_coco_json
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import COCOEvaluator, verify_results
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2 import model_zoo


# Dataset paths
DATASET_ROOT = '../datasets/cocococo'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')

TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')

TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
# VAL_JSON = os.path.join(ANN_ROOT, 'val.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')

# Category metadata of the dataset.
# NOTE(review): detectron2's load_coco_json also sets thing_classes from the
# json's categories, so keep names/ids here in sync with the json files.
DATASET_CATEGORIES = [
    # {"name": "background", "id": 0, "isthing": 1, "color": [220, 20, 60]},
    {"name": "line", "id": 0, "isthing": 1, "color": [220, 20, 60]},
]

# Class names used by plain_register_dataset. They are derived from
# DATASET_CATEGORIES (single source of truth) so editing the categories
# cannot leave a stale hard-coded list behind.
CLASS_NAMES = [c["name"] for c in DATASET_CATEGORIES if c["isthing"] == 1]

# Dataset splits: registered name -> (image folder, annotation json)
PREDEFINED_SPLITS_DATASET = {
    "train_2019": (TRAIN_PATH, TRAIN_JSON),
    "val_2019": (VAL_PATH, VAL_JSON),
}


def register_dataset():
    """
    Register every split listed in PREDEFINED_SPLITS_DATASET.

    Each split is registered with its image folder, its annotation json and a
    freshly built metadata dict from get_dataset_instances_meta().
    """
    for split_name, paths in PREDEFINED_SPLITS_DATASET.items():
        image_root, json_file = paths
        register_dataset_instances(
            name=split_name,
            metadate=get_dataset_instances_meta(),
            json_file=json_file,
            image_root=image_root,
        )


def get_dataset_instances_meta():
    """
    Build the detectron2 metadata dict from DATASET_CATEGORIES.

    Only entries with ``isthing == 1`` are used. Their dataset ids are mapped
    to contiguous ids 0..N-1 in list order.

    return: dict with the keys "thing_dataset_id_to_contiguous_id",
            "thing_classes" and "thing_colors"
    """
    things = [cat for cat in DATASET_CATEGORIES if cat["isthing"] == 1]
    return {
        "thing_dataset_id_to_contiguous_id": {cat["id"]: idx for idx, cat in enumerate(things)},
        "thing_classes": [cat["name"] for cat in things],
        "thing_colors": [cat["color"] for cat in things],
    }


def register_dataset_instances(name, metadate, json_file, image_root):
    """
    Register one split with detectron2.

    DatasetCatalog gets a loader calling load_coco_json(json_file, image_root,
    name). MetadataCatalog gets the two paths, evaluator_type="coco" and every
    key of *metadate*.
    """
    def _load():
        return load_coco_json(json_file, image_root, name)

    DatasetCatalog.register(name, _load)
    MetadataCatalog.get(name).set(
        json_file=json_file,
        image_root=image_root,
        evaluator_type="coco",
        **metadate,
    )


# 注册数据集和元数据
def plain_register_dataset():
    """
    Register train_2019 / val_2019 with minimal metadata.

    This is an alternative to register_dataset(): the metadata holds only
    CLASS_NAMES and the json/image paths. Not called anywhere in this file.
    """
    splits = (
        ("train_2019", TRAIN_JSON, TRAIN_PATH),
        ("val_2019", VAL_JSON, VAL_PATH),
    )
    for split, ann_file, img_dir in splits:
        # Default arguments pin this iteration's values (avoids late binding).
        DatasetCatalog.register(
            split, lambda a=ann_file, i=img_dir, s=split: load_coco_json(a, i, s))
        MetadataCatalog.get(split).set(thing_classes=CLASS_NAMES,
                                       json_file=ann_file,
                                       image_root=img_dir)


class Trainer(DefaultTrainer):
    """DefaultTrainer that evaluates with COCOEvaluator, optionally with test-time augmentation."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a COCOEvaluator for *dataset_name*.

        Results go to *output_folder* (default: ``<OUTPUT_DIR>/inference``).
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        # distributed=True gathers the predictions of all workers before
        # scoring. With distributed=False, each process of a multi-GPU launch
        # scored only its own shard of the test set. A single process is unaffected.
        return COCOEvaluator(dataset_name, cfg, distributed=True, output_dir=output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """Evaluate *model* wrapped in GeneralizedRCNNWithTTA on every cfg.DATASETS.TEST split.

        Returns the results with "_TTA" appended to each key.
        """
        logger = logging.getLogger("detectron2.trainer")
        # Run an evaluation with test-time augmentation (only some R-CNN models support it).
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res


def setup(args):
    """
    Create configs and perform basic setups.

    Settings are applied in this order, later ones winning: detectron2
    defaults, the config file, the hard-coded project settings below, and
    finally ``--opts`` from the command line.
    """
    cfg = get_cfg()  # copy of detectron2's default config
    # Default to the Mask R-CNN config; an explicit --config-file is respected.
    if not args.config_file:
        args.config_file = "../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg.merge_from_file(args.config_file)  # override defaults from the config file

    # Project-specific settings
    cfg.DATASETS.TRAIN = ("train_2019",)
    cfg.DATASETS.TEST = ("val_2019",)
    cfg.DATALOADER.NUM_WORKERS = 2  # data-loading worker processes
    # cfg.INPUT.MAX_SIZE_TRAIN = 400
    # cfg.INPUT.MAX_SIZE_TEST = 400
    # cfg.INPUT.MIN_SIZE_TRAIN = (160,)
    # cfg.INPUT.MIN_SIZE_TEST = 160
    cfg.MODEL.DEVICE = 'cpu'
    # One output class per "thing" category, so the head matches the metadata
    # built by get_dataset_instances_meta(). It was hard-coded to 3 while only
    # one category is listed.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len([c for c in DATASET_CATEGORIES if c["isthing"] == 1])
    # cfg.MODEL.WEIGHTS = "../model/COCO-Detection/faster_rcnn_R_50_FPN_3x/model_final_280758.pkl"  # pre-trained weights
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # pre-trained weights
    cfg.SOLVER.IMS_PER_BATCH = 6  # images per iteration
    # 340 is the (hard-coded) number of training images: iters/epoch = images / batch.
    ITERS_IN_ONE_EPOCH = max(1, int(340 / cfg.SOLVER.IMS_PER_BATCH))
    cfg.SOLVER.MAX_ITER = 560
    cfg.SOLVER.BASE_LR = 0.002
    cfg.SOLVER.MOMENTUM = 0.9
    cfg.SOLVER.WEIGHT_DECAY = 0.0001
    cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0
    cfg.SOLVER.GAMMA = 0.1
    cfg.SOLVER.STEPS = (500,)
    cfg.SOLVER.WARMUP_FACTOR = 1.0 / 1000
    cfg.SOLVER.WARMUP_ITERS = 300
    cfg.SOLVER.WARMUP_METHOD = "linear"
    # Checkpoint roughly once per epoch; kept >= 1 for very large batches.
    cfg.SOLVER.CHECKPOINT_PERIOD = max(1, ITERS_IN_ONE_EPOCH - 1)
    cfg.OUTPUT_DIR = "./output_2017_mask_rcnn_test111/"
    # Command-line overrides go last. When merged before the block above,
    # e.g. ``--opts MODEL.DEVICE cuda MODEL.WEIGHTS ...`` was silently replaced.
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Train the model, or only evaluate it when ``--eval-only`` is given.

    Returns the evaluation results (eval-only) or the value of trainer.train().
    """
    cfg = setup(args)
    print(cfg)

    # Register the dataset splits before building data loaders.
    register_dataset()

    if not args.eval_only:
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    res = Trainer.test(cfg, model)
    if comm.is_main_process():
        verify_results(cfg, res)
    return res


if __name__ == "__main__":
    cli_args = default_argument_parser().parse_args()
    print("Command Line Args:", cli_args)
    # Start main(cli_args) through detectron2's launch() with the CLI's distributed settings.
    launch(
        main,
        cli_args.num_gpus,
        num_machines=cli_args.num_machines,
        machine_rank=cli_args.machine_rank,
        dist_url=cli_args.dist_url,
        args=(cli_args,),
    )

猜你喜欢

转载自blog.csdn.net/m0_64118152/article/details/128819702