[PyTorch] Mask R-CNN Official Source Code Analysis (Part I)

  • Code repository: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/

Training and inference source code

  1. The train_net.py file:
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

# standard-library imports
import argparse
import os

import torch
from maskrcnn_benchmark.config import cfg # import the default configuration
from maskrcnn_benchmark.data import make_data_loader # dataset loading
from maskrcnn_benchmark.solver import make_lr_scheduler # learning-rate schedule
from maskrcnn_benchmark.solver import make_optimizer # builds the optimizer, a wrapper around PyTorch's SGD
from maskrcnn_benchmark.engine.inference import inference # inference code
from maskrcnn_benchmark.engine.trainer import do_train # core training loop

# builds the object detection model
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
# distributed-training utilities
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.imports import import_file
# logging utilities
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config

# apex.amp is also needed here: amp.initialize is called in train() below
from apex import amp
  • First, the main entry point, train_net.main():
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
    # although the flag is spelled --config-file, argparse exposes it as args.config_file (args.config-file is not valid Python)
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER, # required so that all remaining command-line tokens are collected into opts
    )

    args = parser.parse_args()
    
    # number of GPUs, taken from the WORLD_SIZE environment variable
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # whether to run multi-GPU distributed training
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    # merge the options from config_file over the defaults (see the yacs sketch after main())
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze() # freeze the config to prevent further modification

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    # logging setup
    logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    # open the specified config file, read its contents into config_str, and write them to the log
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
    logger.info("Saving config into: {}".format(output_config_path))
    # save overloaded model config in the output directory
    save_config(cfg, output_config_path)

    # call the train function
    model = train(cfg, args.local_rank, args.distributed)

    # run inference on the final model unless --skip-test was given
    if not args.skip_test:
        run_test(cfg, model, args.distributed)


if __name__ == "__main__":
    main()
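
The config handling in main() is provided by yacs: merge_from_file() pulls in the YAML file, merge_from_list() applies the --opts overrides, and freeze() makes the result immutable. Below is a minimal standalone sketch of the merge_from_list/freeze behavior, assuming yacs is installed; the config node and default values here are only illustrative, not maskrcnn-benchmark's real defaults.

from yacs.config import CfgNode as CN

# a tiny config with the same structure as the real cfg (defaults are illustrative)
cfg = CN()
cfg.OUTPUT_DIR = "."
cfg.SOLVER = CN()
cfg.SOLVER.BASE_LR = 0.02

# equivalent of cfg.merge_from_list(args.opts) with
# opts == ["SOLVER.BASE_LR", "0.0025", "OUTPUT_DIR", "runs/exp1"]
cfg.merge_from_list(["SOLVER.BASE_LR", "0.0025", "OUTPUT_DIR", "runs/exp1"])
cfg.freeze()  # any further assignment, e.g. cfg.SOLVER.BASE_LR = 0.1, now raises an error

print(cfg.SOLVER.BASE_LR)  # 0.0025
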
  • The train_net.train() training function
def train(cfg, local_rank, distributed):
    # this calls build_detection_model(), defined under ./maskrcnn_benchmark/modeling/detector/
    # it constructs the object detection model
    # and returns a network built according to our configuration file
    model = build_detection_model(cfg)

    # defaults to "cuda"
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device) # move the model to the target device
    

    # wraps torch.optim.SGD(); the list of parameters to update is built from each tensor's requires_grad attribute (a sketch follows after this function)
    optimizer = make_optimizer(cfg, model)
    # configure the optimizer's learning-rate schedule from the config
    scheduler = make_lr_scheduler(cfg, optimizer)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    # wrap the model for distributed data-parallel training
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    # create an argument dict and initialize the iteration counter to 0
    arguments = {}
    arguments["iteration"] = 0
    # path of the output directory
    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    # DetectronCheckpointer object, passed to do_train later
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    # load the specified weights file
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    # dict.update merges the extra checkpoint data (e.g. the saved iteration) into arguments
    arguments.update(extra_checkpoint_data)

    # when is_train=True, make_data_loader returns a single torch.utils.data.DataLoader
    # (for evaluation it returns a list of loaders); make sure cfg.DATASETS.TRAIN is a tuple of dataset names
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    test_period = cfg.SOLVER.TEST_PERIOD
    if test_period > 0:
        data_loader_val = make_data_loader(cfg, is_train=False, is_distributed=distributed, is_for_period=True)
    else:
        data_loader_val = None

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        cfg,
        model,
        data_loader,
        data_loader_val,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        test_period,
        arguments,
    )

    return model
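
The make_optimizer call above builds per-parameter SGD groups, skipping any tensor whose requires_grad is False. The following is a self-contained sketch of that pattern, not the library's actual implementation; the hyperparameter values and the bias-specific treatment are assumptions made for illustration.

import torch

def make_sgd_optimizer(model, base_lr=0.02, weight_decay=1e-4,
                       bias_lr_factor=2.0, weight_decay_bias=0.0, momentum=0.9):
    params = []
    for name, p in model.named_parameters():
        if not p.requires_grad:          # frozen parameters are not optimized
            continue
        lr, wd = base_lr, weight_decay
        if "bias" in name:               # biases often get a scaled lr and no weight decay
            lr, wd = base_lr * bias_lr_factor, weight_decay_bias
        params.append({"params": [p], "lr": lr, "weight_decay": wd})
    return torch.optim.SGD(params, lr=base_lr, momentum=momentum)

# usage on a toy module
opt = make_sgd_optimizer(torch.nn.Linear(4, 2))
print(len(opt.param_groups))  # 2: one group for the weight, one for the bias
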

The core of this function is the call to do_train(), defined in ./maskrcnn_benchmark/engine/trainer.py and analyzed below:

# imports
import datetime
import logging
import os
import time

import torch
import torch.distributed as dist # distributed communication
from tqdm import tqdm

from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.utils.comm import get_world_size, synchronize
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from maskrcnn_benchmark.engine.inference import inference

from apex import amp

# reduce the loss dictionary across processes (a standalone runnable sketch follows this listing)
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    The losses are reduced in a map-reduce fashion: rank 0 gathers and averages the values from all GPUs.
    """
    world_size = get_world_size()
    if world_size < 2: # single GPU: nothing to reduce
        return loss_dict
    with torch.no_grad(): # no gradients are needed here
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k) # collect the keys
            all_losses.append(loss_dict[k]) # collect the values

        # stack the losses into a single 1-D tensor, one element per loss
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses


# the core of model training
def do_train(
    cfg,
    model, # model object returned by build_detection_model
    data_loader, # PyTorch DataLoader for the training dataset
    data_loader_val,
    optimizer,  # torch.optim.SGD object
    scheduler, # learning-rate schedule, implemented in solver/lr_scheduler.py
    checkpointer, # DetectronCheckpointer, which can also convert Caffe2 Detectron weight files
    device, # training device
    checkpoint_period, # checkpoint save interval in iterations, default 2500
    test_period, # validation interval in iterations
    arguments, # extra parameters (dict); usually only arguments["iteration"], initialized to 0
):
    # logging
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")

    # used to track smoothed (windowed) and global averages of the logged values
    meters = MetricLogger(delimiter="  ") # two spaces are used as the delimiter
    
    # the data loader's len() returns the number of batches it will provide, i.e. cfg.SOLVER.MAX_ITER
    max_iter = len(data_loader)
    start_iter = arguments["iteration"] # 0 by default, but may be changed by a loaded checkpoint
    model.train() # switch the model to training mode (train()'s mode argument defaults to True)
    start_training_time = time.time()
    end = time.time() # timing

    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    dataset_names = cfg.DATASETS.TEST

    # iterate over data_loader
    # each batch is a tuple (images, targets, image_ids)
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        
        if any(len(target) < 1 for target in targets):
            logger.error(f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}" )
            continue
        data_time = time.time() - end # time spent loading this batch
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device) # move the images to the target device
        targets = [target.to(device) for target in targets] # move the targets to the target device

        loss_dict = model(images, targets) # compute the loss dict from images and targets
        # sum the individual losses
        losses = sum(loss for loss in loss_dict.values())
   
        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        # sum of the losses after the rank-0 reduce
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced) # update the smoothed averages

        optimizer.zero_grad() # clear cached gradients
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward() # backward pass (computes gradients)
        optimizer.step() # update the parameters
        scheduler.step() # advance the learning-rate schedule by one step

        batch_time = time.time() - end # time for one full iteration
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        # estimate the remaining training time from the global average iteration time
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        # log the training status every 20 iterations
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        # save a checkpoint every checkpoint_period iterations
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        # validation
        if data_loader_val is not None and test_period > 0 and iteration % test_period == 0:
            meters_val = MetricLogger(delimiter="  ")
            synchronize()
            _ = inference(  # The result can be used for additional logging, e. g. for TensorBoard
                model,
                # The method changes the segmentation mask format in a data loader,
                # so every time a new data loader is created:
                make_data_loader(cfg, is_train=False, is_distributed=(get_world_size() > 1), is_for_period=True),
                dataset_name="[Validation]",
                iou_types=iou_types,
                box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
                device=cfg.MODEL.DEVICE,
                expected_results=cfg.TEST.EXPECTED_RESULTS,
                expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                output_folder=None,
            )
            synchronize()
            model.train()
            with torch.no_grad():
                # Should be one image for each GPU:
                for iteration_val, (images_val, targets_val, _) in enumerate(tqdm(data_loader_val)):
                    images_val = images_val.to(device)
                    targets_val = [target.to(device) for target in targets_val]
                    loss_dict = model(images_val, targets_val)
                    losses = sum(loss for loss in loss_dict.values())
                    loss_dict_reduced = reduce_loss_dict(loss_dict)
                    losses_reduced = sum(loss for loss in loss_dict_reduced.values())
                    meters_val.update(loss=losses_reduced, **loss_dict_reduced)
            synchronize()
            logger.info(
                meters_val.delimiter.join(
                    [
                        "[Validation]: ",
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters_val),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        # also save when the maximum number of iterations is reached
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)
    # log the total training time
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
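
To see what reduce_loss_dict achieves, here is a minimal, self-contained sketch (separate from maskrcnn-benchmark) that spawns two CPU processes with the gloo backend and reduces a fake loss dict onto rank 0; the port and loss values are arbitrary.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    # pretend each process computed different loss values
    loss_dict = {"loss_cls": torch.tensor(float(rank + 1)),
                 "loss_box": torch.tensor(2.0 * (rank + 1))}
    names = sorted(loss_dict.keys())
    stacked = torch.stack([loss_dict[k] for k in names])
    dist.reduce(stacked, dst=0)          # sum onto rank 0
    if dist.get_rank() == 0:
        stacked /= world_size            # rank 0 now holds the average
        print(dict(zip(names, stacked.tolist())))  # {'loss_box': 3.0, 'loss_cls': 1.5}
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
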
  • The train_net.run_test() inference function
def run_test(cfg, model, distributed):
    if distributed:
        model = model.module
    torch.cuda.empty_cache()  # TODO check if it helps -- releases cached, unoccupied GPU memory
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON: # if MASK_ON is true, also evaluate segmentation
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    output_folders = [None] * len(cfg.DATASETS.TEST) # one output folder per test dataset
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names): # iterate over the test datasets
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder) # create the output folder
            output_folders[idx] = output_folder # store the folder path
    # build the test data loaders from the config
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    # run the inference procedure on each test dataset
    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            bbox_aug=cfg.TEST.BBOX_AUG.ENABLED,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize() # synchronization point for multi-GPU inference

When running the model's inference logic, this function calls inference() in ./maskrcnn_benchmark/engine/inference.py, analyzed below:

# inference logic
import logging
import time
import os

import torch
from tqdm import tqdm

# evaluation function
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
from .bbox_aug import im_detect_bbox_aug

# run the model over the dataset and collect the predictions
def compute_on_dataset(model, data_loader, device, bbox_aug, timer=None):
    model.eval() # switch to eval mode, which mainly affects dropout and batch-norm behavior
    results_dict = {}
    cpu_device = torch.device("cpu")
    for _, batch in enumerate(tqdm(data_loader)):
        images, targets, image_ids = batch
        with torch.no_grad(): # no gradients are needed when running the model for inference
            if timer:
                timer.tic()
            if bbox_aug:
                output = im_detect_bbox_aug(model, images, device)
            else:
                output = model(images.to(device)) # forward pass
            if timer:
                if not device.type == 'cpu':
                    torch.cuda.synchronize()
                timer.toc()
            output = [o.to(cpu_device) for o in output] # move the results to the CPU
        # store the results keyed by image id
        results_dict.update(
            {img_id: result for img_id, result in zip(image_ids, output)}
        )
    return results_dict

# gather the predictions from all GPUs
def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu):
    all_predictions = all_gather(predictions_per_gpu)
    if not is_main_process():
        return
    # merge the list of dicts
    predictions = {}
    for p in all_predictions:
        predictions.update(p)
    # convert a dict where the key is the index in a list
    image_ids = list(sorted(predictions.keys()))
    if len(image_ids) != image_ids[-1] + 1:
        logger = logging.getLogger("maskrcnn_benchmark.inference")
        logger.warning(
            "Number of images that were gathered from multiple processes is not "
            "a contiguous set. Some images might be missing from the evaluation"
        )

    # convert to a list
    predictions = [predictions[i] for i in image_ids]
    return predictions

# model inference
def inference(
        model,  # model object returned by build_detection_model
        data_loader, # PyTorch DataLoader for the custom dataset
        dataset_name, # str, name of the dataset
        iou_types=("bbox",), # IoU types, default is bbox only
        box_only=False, # usually cfg.MODEL.RPN_ONLY (False by default)
        bbox_aug=False,
        device="cuda",
        expected_results=(),
        expected_results_sigma_tol=4,
        output_folder=None, # custom output folder
):
    # convert to a torch.device for efficiency
    device = torch.device(device)
    num_devices = get_world_size() # number of devices (world size)
    # logging
    logger = logging.getLogger("maskrcnn_benchmark.inference")
    # dataset
    dataset = data_loader.dataset
    logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
    # timers
    total_timer = Timer()
    inference_timer = Timer()
    total_timer.tic()
    # call compute_on_dataset (above) to obtain the predictions
    predictions = compute_on_dataset(model, data_loader, device, bbox_aug, inference_timer)
    # wait for all processes to complete before measuring the time
    synchronize()
    # log the total elapsed time
    total_time = total_timer.toc()
    total_time_str = get_time_str(total_time)
    logger.info(
        "Total run time: {} ({} s / img per device, on {} devices)".format(
            total_time_str, total_time * num_devices / len(dataset), num_devices
        )
    )
    total_infer_time = get_time_str(inference_timer.total_time)
    logger.info(
        "Model inference time: {} ({} s / img per device, on {} devices)".format(
            total_infer_time,
            inference_timer.total_time * num_devices / len(dataset),
            num_devices,
        )
    )
    # accumulate the predictions from all GPU devices and return them
    predictions = _accumulate_predictions_from_multiple_gpus(predictions)
    if not is_main_process():
        return

    if output_folder:
        # save the raw predictions (they can be reloaded later; see the snippet after this listing)
        torch.save(predictions, os.path.join(output_folder, "predictions.pth"))

    extra_args = dict(
        box_only=box_only,
        iou_types=iou_types,
        expected_results=expected_results,
        expected_results_sigma_tol=expected_results_sigma_tol,
    )

    # run the evaluation function and return the quality of the predictions
    return evaluate(dataset=dataset,
                    predictions=predictions,
                    output_folder=output_folder,
                    **extra_args)
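
Since inference() stores the accumulated predictions as predictions.pth in the output folder, they can be reloaded later for offline inspection or re-evaluation. A small usage sketch follows; the path is hypothetical and depends on cfg.OUTPUT_DIR and the dataset name, and maskrcnn_benchmark must be importable because the entries are pickled BoxList objects.

import torch

# hypothetical path: <OUTPUT_DIR>/inference/<dataset_name>/predictions.pth
predictions = torch.load("output/inference/coco_2014_minival/predictions.pth")

# predictions is a list with one entry per image, ordered by image index in the dataset;
# each entry is the model output for that image (a BoxList in maskrcnn-benchmark)
print(len(predictions), type(predictions[0]))
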
  2. The test_net.py file:
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import torch
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir

# Check if we can enable mixed-precision via apex.amp
try:
    from apex import amp
except ImportError:
    raise ImportError('Use APEX for mixed precision via apex.amp')


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
    # path to the config file
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    # local_rank
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--ckpt",
        help="The path to the checkpoint for test, default is the latest checkpoint.",
        default=None,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()
    
    # number of GPUs
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # set the distributed flag based on the number of GPUs
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()
    # merge the settings from the specified config file into the global defaults
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze() # freeze the config to prevent changes

    save_dir = ""
    logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    # build the model from the config
    model = build_detection_model(cfg)
    # move the model to the specified device
    model.to(cfg.MODEL.DEVICE)

    # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    # parent path of the output folders
    output_dir = cfg.OUTPUT_DIR
    # load the weights (the "latest checkpoint" lookup is sketched after this file)
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
    ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
    _ = checkpointer.load(ckpt, use_latest=args.ckpt is None)

    # set the IoU types
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)

    # one output folder per test dataset
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    # create the output folders
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    # build the test data loaders
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    # call inference on each test dataset
    for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            bbox_aug=cfg.TEST.BBOX_AUG.ENABLED,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()


if __name__ == "__main__":
    main()
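
When --ckpt is not given, checkpointer.load(ckpt, use_latest=True) falls back to the most recent checkpoint recorded in the output directory. Below is a rough sketch of that lookup, under the assumption that the checkpointer writes the path of its last save into a plain-text `last_checkpoint` file inside the output directory (as maskrcnn-benchmark's Checkpointer does); the function name and paths here are illustrative.

import os

def get_latest_checkpoint(output_dir):
    # assumed convention: the newest weights path is kept in "last_checkpoint"
    marker = os.path.join(output_dir, "last_checkpoint")
    if not os.path.exists(marker):
        return None  # nothing saved yet: fall back to cfg.MODEL.WEIGHT
    with open(marker, "r") as f:
        return f.read().strip()

print(get_latest_checkpoint("output"))  # e.g. "output/model_0002500.pth", or None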


Reposted from blog.csdn.net/qq_43348528/article/details/107311771