Detectron2 Source Code Walkthrough for Object Detection

Table of Contents

Chapter 1: Architecture Design and Implementation

Chapter 2: Network Structure Design -- Detectron2\detectron2\modeling

1. backbone

2. proposal generator

3. roi_heads

4. meta-architectures

Chapter 3: Training and Inference -- Detectron2\detectron2\engine

1. Interface

2. Implementation

2.1. Trainer class implementations

2.2. Auxiliary hook class implementations

3. Main entry scripts

3.1. Layered abstractions

3.2. Plain implementation

4. Inference

Chapter 4: Performance Evaluation -- Detectron2\detectron2\evaluation

1. Interface

2. Implementation

Chapter 5: Data

1. Dataset overview

2. Building the dataset-loading API (Datasets)

3. Converting datasets into model-ready data

4. How datasets are consumed

Chapter 6: Other Supporting Functionality

1. Saving and loading models

2. Logging


Detectron2 is a platform for deep-learning object detection. Built on the PyTorch framework, it integrates a variety of detection algorithms and supports instance detection, instance/semantic/panoptic segmentation, and keypoint detection.

The relationship between Detectron2 and PyTorch, and what each layer contributes:

Projects          users build on and extend detectron2 to implement new networks
_____________________________
Detectron2        1. provides the Trainer and Evaluation framework, plus auxiliary functionality such as DataLoader, Logging and Checkpoints;
                  2. implements the commonly used object-detection networks.
                  (The solver used for training and the basic CNN layers come from PyTorch.)
_____________________________
Pytorch           implements the basic CNN layers

Detectron2 source code: https://github.com/facebookresearch/detectron2

Detectron2 documentation: https://detectron2.readthedocs.io/index.html

This article focuses on Detectron2's architecture and how to use it.

Chapter 1: Architecture Design and Implementation

The directory tree of the Detectron2 project:

Detectron2\detectron2
├── engine            main flow of training (train) and testing (test)
├── evaluation        evaluating performance on datasets
├── data              data loading
├── modeling          deep network structures
├── layers            wrappers around the layers PyTorch defines, plus extensions (layers not built into PyTorch, e.g. ASPP); the basic layers from which the modeling module is built
├── solver            the solver (optimizer and learning-rate schedule)
├── checkpoint        saving and loading model weights
├── config            reading configuration files
├── export
├── model_zoo
├── projects          example projects
├── structures
└── utils

The relationships among the engine, data, modeling (the model) and solver (the optimizer) modules:

  • engine is the logic that strings the data flow together: it calls the data, model and optimizer modules to implement training and inference;
  • data is the input to the model; for inference, the model additionally loads the trained weights;
  • the model is the input to the optimizer.

evaluation computes statistical metrics over the inference results.

In addition, there is general software support: checkpoint saves checkpoints and resumes from them, and config parses configuration files.

Training takes a dataset (images and ground truth), a pretrained model and a config file as input, and produces trained weights (and logs). The training workflow is:

build the dataset (Datasets) => load the data (DataLoaders) => build the model (Models) => set up the optimizer and the learning rate => iterate: run the forward and backward passes. The key steps are:

        data = next(self._data_loader_iter)
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())   # sum the losses of the individual heads
        self.optimizer.zero_grad()  # zero the gradients of all parameters
        losses.backward()           # back-propagate to compute the parameter gradients
        self.optimizer.step()       # take one optimizer step to update the parameters

Inference takes images, trained weights and a config file as input, and outputs visualized images as well as the raw predictions (for later performance evaluation). The inference workflow is:

build the model => load the model weights with load_state_dict (done once at initialization) =>

iterate: load and preprocess an image => feed it to the model for inference (one image per call) => save the visualized result (image included) =>

after the loop, write all predictions to a single file (without the images).

Evaluating the predictions requires the dataset ground truth and the inference results as input.

Chapter 2: Network Structure Design -- Detectron2\detectron2\modeling

torch/nn/modules/module.py defines class Module, the base class of all modules (each module may contain many layers); multiple Modules can be connected to form a network model. Modules exist at every scale: the whole network is a Module; at a smaller scale, the backbone is one Module and each head is its own Module; at a still smaller scale, the repeated blocks inside the backbone are also Modules.

A meta-architecture class, i.e. the complete network structure, is composed of the following three major modules, all of which inherit from nn.Module: backbone, proposal generator and roi_heads, where:

  • backbone: extracts feature maps from the images.
  • proposal generator: generates proposals from the backbone features. For example, RPN heads take the feature maps and perform objectness classification and bounding-box regression for the anchors.
  • roi_heads: an ROI head that performs per-region computation.

The meta-architecture implements forward, which wires up the computation among the major modules. Each major module contains many layers, and the def forward(self, *input) interface of Module is where the developer defines the order in which those layers are connected. In the forward of the whole network, training has to compute the losses and convert the annotated ground truth into the same format as the feature level on which the losses are computed; inference needs no losses, but usually adds post-processing that turns the predicted features into the final output.

To be able to instantiate different networks from the structure names written in the config file, and to combine different major modules into different meta-architectures by name, base-class interfaces are defined for Modules at every scale so that the interfaces stay uniform. In addition, Detectron2 uses the Registry mechanism: every network structure calls the registry when it is defined, and at use time the registry is queried again to load the structure whose name appears in the config file. This guarantees that the meta-architecture can be selected from the config parameters, and that, inside the meta-architecture, different backbones, proposal generators and roi_heads can be combined according to the config. Registry itself is defined in a separate utility package, fvcore.common.registry.
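The pattern is easy to see in isolation. Below is a minimal sketch of the Registry mechanism, assuming fvcore's Registry API as re-exported by detectron2.utils.registry; the registry name MY_BACKBONE_REGISTRY and the class ToyBackbone are made up for illustration.

from detectron2.utils.registry import Registry

MY_BACKBONE_REGISTRY = Registry("MY_BACKBONE")   # one registry per kind of component


@MY_BACKBONE_REGISTRY.register()                 # definition time: record "ToyBackbone" -> class
class ToyBackbone:
    def __init__(self, cfg):
        self.depth = cfg["depth"]


# use time: look the component up by the string that would come from the config file
name = "ToyBackbone"
backbone = MY_BACKBONE_REGISTRY.get(name)({"depth": 50})
print(type(backbone).__name__, backbone.depth)   # ToyBackbone 50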

1. backbone

1.1. backbone interface

Detectron2/detectron2/modeling/backbone/backbone.py defines the abstract base class for network backbones.

class Backbone(nn.Module, metaclass=ABCMeta):
    @abstractmethod
    def forward(self):  # the forward interface must be implemented
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

1.2. block interface

A large backbone is built from multiple stages (in the FPN paper: "Usually, layers that produce the same feature map spatial size are defined as one stage"), and each stage contains several repeated block structures.

Detectron2/detectron2/layers/blocks.py defines CNNBlockBase. Blocks are built from basic operations such as Conv2d and relu_; Detectron2/detectron2/layers/wrappers.py extends the commonly used layers: class Conv2d(torch.nn.Conv2d), class ConvTranspose2d(torch.nn.ConvTranspose2d), class BatchNorm2d(torch.nn.BatchNorm2d).

class CNNBlockBase(nn.Module)  # class Module(object) is PyTorch's base class for all neural network modules
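As a small usage sketch (assuming the detectron2.layers API, where the wrapped Conv2d accepts extra norm and activation arguments), a block typically assembles the wrapped layers like this:

import torch
import torch.nn.functional as F
from detectron2.layers import Conv2d, get_norm

conv = Conv2d(
    64, 128, kernel_size=3, padding=1, bias=False,
    norm=get_norm("BN", 128),   # the attached normalization layer is applied inside forward()
    activation=F.relu,          # as is the activation
)
out = conv(torch.randn(2, 64, 56, 56))
print(out.shape)                # torch.Size([2, 128, 56, 56])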

1.3. block and backbone implementations

Detectron2/detectron2/modeling/backbone/resnet.py subclasses CNNBlockBase to define the different blocks:

class BasicBlock(CNNBlockBase)
class BottleneckBlock(CNNBlockBase)
class DeformBottleneckBlock(CNNBlockBase)
class BasicStem(CNNBlockBase)

ResNet inherits from Backbone and uses the blocks defined above to implement the ResNet backbone:

class ResNet(Backbone)

1.4. Registration and lookup

At definition time:

@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg, input_shape):
    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)

At call time, Detectron2/detectron2/modeling/backbone/build.py looks up the registered backbone by the name given in the config file:

    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)

Detectron2/detectron2/modeling/backbone/fpn.py in turn takes the ResNet produced by build_resnet_backbone as a sub-structure and extends it into different FPN backbones:

@BACKBONE_REGISTRY.register()
def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):

and

@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
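As a usage sketch (assuming the standard detectron2 config keys; the model-zoo config file is only an example), a registered backbone can be built from a config and inspected like this:

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
print(cfg.MODEL.BACKBONE.NAME)                 # "build_resnet_fpn_backbone" -> looked up in BACKBONE_REGISTRY
backbone = build_backbone(cfg)                 # calls BACKBONE_REGISTRY.get(name)(cfg, input_shape)
print({k: v.channels for k, v in backbone.output_shape().items()})   # e.g. the p2..p6 feature maps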

2. proposal generator

@RPN_HEAD_REGISTRY.register()
class StandardRPNHead(nn.Module):

@PROPOSAL_GENERATOR_REGISTRY.register()
class RPN(nn.Module):

At use time, Detectron2/detectron2/modeling/proposal_generator/build.py likewise selects the proposal generator named in the config file:

PROPOSAL_GENERATOR_REGISTRY.get(name)(cfg, input_shape)

3. roi_heads

3.1. Interface

Detectron2/detectron2/modeling/roi_heads/roi_heads.py

class ROIHeads(torch.nn.Module):

3.2. Implementations

@ROI_HEADS_REGISTRY.register()
class Res5ROIHeads(ROIHeads):

@ROI_HEADS_REGISTRY.register()
class StandardROIHeads(ROIHeads):

Heads for other tasks include mask_head, keypoint_head, and so on.

4. meta-architectures

The forward method in Detectron2/detectron2/modeling/meta_arch/rcnn.py lays out the data flow:

the batched_inputs produced by class:`DatasetMapper` are preprocessed: images = self.preprocess_image(batched_inputs)

=> fed into the backbone to extract features: features = self.backbone(images.tensor)

=> the features and images are passed to the proposal_generator module: proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)

=> the proposals are passed to roi_heads: _, detector_losses = self.roi_heads(images, features, proposals, gt_instances).

@META_ARCH_REGISTRY.register()
class GeneralizedRCNN(nn.Module):
    def forward(self, batched_inputs):
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses
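As a hedged sketch of driving this forward method directly (assuming the standard batched_inputs format, a list of dicts whose "image" value is a (C, H, W) tensor; the model-zoo config and the dummy image are only examples):

import torch
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.DEVICE = "cpu"                      # keep the sketch runnable without a GPU
model = build_model(cfg)                      # the meta-architecture named by cfg.MODEL.META_ARCHITECTURE
DetectionCheckpointer(model).load(model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))

model.eval()                                  # not self.training -> forward() dispatches to self.inference()
with torch.no_grad():
    image = torch.rand(3, 480, 640) * 255     # dummy BGR image tensor in (C, H, W)
    outputs = model([{"image": image, "height": 480, "width": 640}])
print(outputs[0]["instances"])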

Chapter 3: Training and Inference -- Detectron2\detectron2\engine

1. Interface

Detectron2\detectron2\engine\train_loop.py defines the TrainerBase and HookBase base classes.

HookBase defines four key points in the pipeline: before_train, after_train, before_step and after_step. All auxiliary hook classes inherit from HookBase.

TrainerBase is the base class that controls the overall pipeline: it registers hooks and calls them at the appropriate points of the pipeline, so developers can flexibly define and use auxiliary hook classes (for example, timing each iteration, or evaluating results every fixed number of iterations).

class HookBase:    
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        hook.after_train()
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

TrainerBase defines the register_hooks method to register the various auxiliary hook classes and the train method to drive the whole pipeline, and leaves the core training step run_step as an interface to be implemented:

class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we made here is: the training runs in a loop.
    A subclass can implement what the loop is.
    We made no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training.

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):  # 1. Register the hook classes the user needs (they inherit from HookBase); registering simply appends them to the list self._hooks = []
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:  # 2. call the hooks following the pipeline
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    # 3. each of these methods calls, in order, the corresponding method of every hook registered in _hooks
    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

    # 4. the core training step, to be implemented by a subclass
    def run_step(self):   
        raise NotImplementedError

2. Implementation

2.1. Trainer class implementations

SimpleTrainer (in Detectron2\detectron2\engine\train_loop.py) inherits from TrainerBase and provides a concrete implementation of the run_step(self) method that TrainerBase left open, including the forward pass that computes the losses and the backward pass:

class SimpleTrainer(TrainerBase):   #(class:`SimpleTrainer`, only does minimal SGD training and nothing else)
    def __init__(self, model, data_loader, optimizer):  # takes model, data_loader and optimizer as input
        super().__init__()
        model.train()
        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):  # only the key steps are shown
        data = next(self._data_loader_iter)
        loss_dict = self.model(data)            # data is the input to the model
        losses = sum(loss_dict.values())        #1. Compute the loss with a data from the data_loader.
        self.optimizer.zero_grad()              #2. Compute the gradients with the above loss.
        losses.backward()
        # use a new stream so the ops don't wait for DDP
        with torch.cuda.stream(
            torch.cuda.Stream()
        ) if losses.device.type == "cuda" else _nullcontext():
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time  # data_time is measured around data loading in the full implementation (omitted here)
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)
        self.optimizer.step()                   #3. Update the model with the optimizer.

DefaultTrainer (in Detectron2\detectron2\engine\defaults.py) inherits from SimpleTrainer and implements the full training flow: it creates the model, optimizer, scheduler and dataloader, and adds the auxiliary functionality of the hooks classes according to the config file:

class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):                  #1. Create model, optimizer, scheduler, dataloader from the given config.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)  # the model is the input to the optimizer
        data_loader = self.build_train_loader(cfg)
        super().__init__(model, data_loader, optimizer)
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

    def resume_or_load(self, resume=True):    # 2. resume_or_load loads weights, or resumes training from a previous checkpoint.

    def build_hooks(self):                    #3. Register a few common hooks defined by the config:default configurations for optimizer, learning rate schedule, logging, evaluation, checkpointing etc.
        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret

2.2. Auxiliary hook class implementations

Detectron2\detectron2\engine\hooks.py implements the hook classes; they all inherit from HookBase.

The hook classes include (a minimal custom hook is sketched after this list):

  • CallbackHook,
  • IterationTimer, measures the time spent on each iteration;
  • PeriodicWriter, periodically writes the log;
  • PeriodicCheckpointer, periodically saves the model;
  • LRScheduler,
  • AutogradProfiler,
  • EvalHook, periodically evaluates the training results;
  • PreciseBN
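A minimal custom hook, sketched against the HookBase interface shown above (the hook name and the printed message are illustrative only; it reads trainer.storage, the event storage described later in the logging section):

from detectron2.engine import HookBase


class LossPrinterHook(HookBase):
    """Print the latest total_loss every `period` iterations."""

    def __init__(self, period=100):
        self._period = period

    def after_step(self):
        # self.trainer is attached by TrainerBase.register_hooks via weakref.proxy
        if (self.trainer.iter + 1) % self._period == 0:
            latest = self.trainer.storage.latest()          # dict: name -> (value, iteration)
            if "total_loss" in latest:
                print("iter", self.trainer.iter, "total_loss", latest["total_loss"][0])


# registered like the built-in hooks:
# trainer = DefaultTrainer(cfg)
# trainer.register_hooks([LossPrinterHook(period=100)])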

3. Main entry scripts

The scripts under Detectron2\tools provide the main entry points for training.

3.1. Layered abstractions

train_net.py uses the high-level abstractions described above, adding one more layer of abstraction on top of the TrainerBase -> SimpleTrainer -> DefaultTrainer chain to hook in the evaluation module and inference with test-time augmentation:

class Trainer(DefaultTrainer):
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        
    def test_with_TTA(cls, cfg, model):

The main function instantiates trainer = Trainer(cfg) and calls trainer.train() to run training, which gives the "standard default" behavior; a condensed sketch of this entry point is shown below.
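A condensed sketch of the entry point, assuming the default detectron2 helpers (default_argument_parser, default_setup, launch) and omitting the evaluation and TTA additions:

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch


def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)    # the structure names in the config drive the registries
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)                 # logging, output dir, seed, ...
    return cfg


def main(args):
    cfg = setup(args)
    trainer = DefaultTrainer(cfg)            # or the Trainer subclass defined in train_net.py
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(main, args.num_gpus, num_machines=args.num_machines,
           machine_rank=args.machine_rank, dist_url=args.dist_url, args=(args,))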

Detectron2\projects contains example projects built with detectron2 (DeepLab, DensePose, PointRend, TensorMask, TridentNet), each with its own train_net.py.

3.2. Plain implementation

In addition, tools provides a thin abstraction: plain_train_net.py does not use the abstractions above. It skips the TrainerBase -> SimpleTrainer -> DefaultTrainer chain and lays the flow out directly (create the model, optimizer, scheduler and dataloader, then run the training loop), so that users can quickly implement a fairly basic training and inference flow; a sketch in this style follows.
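A thin, plain_train_net-style sketch of the same flow without DefaultTrainer (assuming the standard build_* helpers and a config file path given only as an example; checkpointing, logging and evaluation are omitted):

from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer

cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  # example path

model = build_model(cfg)
model.train()
optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)
data_loader = build_detection_train_loader(cfg)     # infinite iterator of batched list[dict]

for iteration, data in zip(range(cfg.SOLVER.MAX_ITER), data_loader):
    loss_dict = model(data)
    losses = sum(loss_dict.values())
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
    scheduler.step()                                 # detectron2 schedules the lr per iteration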

4. Inference

The inference class DefaultPredictor is also defined in Detectron2\detectron2\engine\defaults.py. It is much simpler than training: it only needs to build the model, put it in inference mode, load an image, and run inference; a short usage sketch follows.
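A short usage sketch of DefaultPredictor (the image path and the score threshold are illustrative; the weights here come from the model zoo as an example):

import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

predictor = DefaultPredictor(cfg)            # builds the model and loads the weights once
image = cv2.imread("input.jpg")              # BGR, HWC, uint8 -- the format DefaultPredictor expects
outputs = predictor(image)                   # one image per call
print(outputs["instances"].pred_classes)
print(outputs["instances"].pred_boxes)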

Chapter 4: Performance Evaluation -- Detectron2\detectron2\evaluation

During training, evaluation results are needed at regular intervals; after inference, statistical metrics are needed to assess the overall performance of the model.

Evaluation, in short, compares the stored predictions against the ground truth and computes statistics.

1. Interface

class DatasetEvaluator is defined in Detectron2\detectron2\evaluation\evaluator.py. Its def process(self, inputs, outputs) method handles one pair of ground truth and predictions; once all pairs have been processed, def evaluate(self) is called to compute the statistical metrics.

For multi-task networks, several DatasetEvaluator subclasses are needed to measure the metrics of each head; the wrapper class DatasetEvaluators(DatasetEvaluator) runs each of them in turn. A minimal custom evaluator is sketched below.
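A minimal custom evaluator, sketched against the DatasetEvaluator interface (counting detections per image is only a toy metric used for illustration):

from detectron2.evaluation import DatasetEvaluator


class DetectionCounter(DatasetEvaluator):
    def reset(self):
        self._num_images = 0
        self._num_detections = 0

    def process(self, inputs, outputs):
        # called once per batch with matched inputs (ground truth) and outputs (predictions)
        for _, output in zip(inputs, outputs):
            self._num_images += 1
            self._num_detections += len(output["instances"])

    def evaluate(self):
        # called once at the end; returns a dict of metrics
        return {"detections_per_image": self._num_detections / max(self._num_images, 1)}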

2. Implementations

Concrete evaluators for specific datasets are provided, e.g. cityscapes_evaluation.py and coco_evaluation.py.

Chapter 5: Data

1. Dataset overview

A dataset consists of annotation data and metadata:

Annotation data: the raw images and the structured annotations (image file name, bbox, category_id).

Metadata: data about the data, mainly information describing the dataset's properties (e.g. thing_classes, the class name corresponding to each category_id).

Once a dataset has been prepared, using it for deep-learning training and inference requires, first, an API that reads the dataset in: given a dataset name, it returns a list mapping each image to its annotations. This list is a lightweight description of the dataset. detectron2 ships with the classic builtin datasets, such as COCO, LVIS instance segmentation, Cityscapes, VOC20{07,12} and ADE20k Scene Parsing, and developers can also define their own datasets.

Then, with this list as input, the dataset is converted into data the model can consume: the data is loaded into memory, data augmentation is applied, the samples are converted into batched torch Tensors, and an iterator is returned that run_step can repeatedly draw from during training.

The sections below explain the dataset API, loading the dataset, and how the dataset is used.

2. Building the dataset-loading API (Datasets)

After the dataset is prepared, an API to load it has to be written. A registration mechanism establishes the mapping between a dataset name (name) and the function that reads it (func), so that different datasets can be selected through the config-file parameters. Loading a dataset here means loading the images and the annotation data, i.e. the reader function func must return a list[dict].

The interfaces for registering and retrieving datasets are defined in Detectron2/detectron2/data/catalog.py; datasets register themselves through the class _DatasetCatalog:

class _DatasetCatalog(UserDict):
    """
    A global dictionary that stores information about the datasets and how to obtain them.

    It contains a mapping from strings
    (which are names that identify a dataset, e.g. "coco_2014_train")
    to a function which parses the dataset and returns the samples in the
    format of `list[dict]`.

    The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
    if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.

    The purpose of having this catalog is to make it easy to choose
    different datasets, by just using the strings in the config.
    """

    def register(self, name, func):
        """
        Args:
            name (str): the name that identifies a dataset, e.g. "coco_2014_train".
            func (callable): a callable which takes no arguments and returns a list of dicts.
                It must return the same results if called multiple times.
        """
        assert callable(func), "You must register a function with `DatasetCatalog.register`!"
        assert name not in self, "Dataset '{}' is already registered!".format(name)
        self[name] = func

    def get(self, name):
        """
        Call the registered function and return its results.

        Args:
            name (str): the name that identifies a dataset, e.g. "coco_2014_train".

        Returns:
            list[dict]: dataset annotations.
        """
        try:
            f = self[name]
        except KeyError as e:
            raise KeyError(
                "Dataset '{}' is not registered! Available datasets are: {}".format(
                    name, ", ".join(list(self.keys()))
                )
            ) from e
        return f()
DatasetCatalog = _DatasetCatalog()
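A hedged sketch of registering a custom dataset through DatasetCatalog and MetadataCatalog (the dataset name, file paths and class list are made up for illustration):

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode


def load_my_dataset():
    # must return list[dict] in the Detectron2 Dataset format
    return [{
        "file_name": "datasets/my_data/0001.jpg",
        "image_id": 1,
        "height": 480,
        "width": 640,
        "annotations": [
            {"bbox": [100, 100, 200, 200], "bbox_mode": BoxMode.XYXY_ABS, "category_id": 0},
        ],
    }]


DatasetCatalog.register("my_dataset_train", load_my_dataset)
MetadataCatalog.get("my_dataset_train").set(thing_classes=["widget"])

# afterwards the name can be referenced from the config: cfg.DATASETS.TRAIN = ("my_dataset_train",)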

All builtin datasets are registered in Detectron2/detectron2/data/datasets/builtin.py. Taking the cityscapes dataset as an example, the registration looks like this:

def register_all_cityscapes(root):
    for key, (image_dir, gt_dir) in _RAW_CITYSCAPES_SPLITS.items():
        meta = _get_builtin_metadata("cityscapes")
        image_dir = os.path.join(root, image_dir)
        gt_dir = os.path.join(root, gt_dir)

        inst_key = key.format(task="instance_seg")
        DatasetCatalog.register(
            inst_key,
            lambda x=image_dir, y=gt_dir: load_cityscapes_instances(
                x, y, from_json=True, to_polygons=True
            ),
        )
        MetadataCatalog.get(inst_key).set(
            image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_instance", **meta
        )

The func passed to DatasetCatalog is the concrete implementation that loads the dataset; each dataset needs its own, and these scripts live under Detectron2/detectron2/data/datasets. For cityscapes, the Detectron2/detectron2/data/datasets/cityscapes.py script implements it as follows:

def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
        from_json (bool): whether to read annotations from the raw json file or the png files.
        to_polygons (bool): whether to represent the segmentation as polygons
            (COCO's format) instead of masks (cityscapes's format).

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    """
    if from_json:
        assert to_polygons, (
            "Cityscapes's json annotations are in polygon format. "
            "Converting to mask format is not supported now."
        )
    files = get_cityscapes_files(image_dir, gt_dir)

    logger.info("Preprocessing cityscapes annotations ...")
    # This is still not fast: all workers will execute duplicate works and will
    # take up to 10m on a 8GPU server.
    pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4))

    ret = pool.map(
        functools.partial(cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons),
        files,
    )
    logger.info("Loaded {} images from {}".format(len(ret), image_dir))

    # Map cityscape ids to contiguous ids
    from cityscapesscripts.helpers.labels import labels

    labels = [l for l in labels if l.hasInstances and not l.ignoreInEval]
    dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)}
    for dict_per_image in ret:
        for anno in dict_per_image["annotations"]:
            anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]]
    return ret

The dict returned by def cityscapes_files_to_dict(files, from_json, to_polygons) is:

def cityscapes_files_to_dict(files, from_json, to_polygons):
        ret = {
            "file_name": image_file,
            "image_id": os.path.basename(image_file),
            "height": inst_image.shape[0],
            "width": inst_image.shape[1],
            "annotations": annos,
        }

and each annotation dict is:

            anno = {}
            anno["iscrowd"] = label_name.endswith("group")
            anno["category_id"] = label.id
            anno["segmentation"] = poly_coord
            anno["bbox"] = (xmin, ymin, xmax, ymax)
            anno["bbox_mode"] = BoxMode.XYXY_ABS
            annos.append(anno)

3. Converting datasets into model-ready data

Preparing the data the model will consume: first, the loading API above is called and returns the list of dicts described earlier, where each dict stores the path of one image and its annotation information, and the list covers all images. Then, with this list[dict] as input, model-ready data is produced: the images are read into memory, data augmentation is applied, and the samples are grouped into batches.
Detectron2/detectron2/data/build.py

def build_detection_train_loader(cfg, mapper=None):
    """
    The batched ``list[mapped_dict]`` is what this dataloader will yield.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be ``DatasetMapper(cfg, True)``.

    Returns:
        an infinite iterator of training data
    """
    dataset_dicts = get_detection_dataset_dicts(   # 1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
    )
    dataset = DatasetFromList(dataset_dicts, copy=False)

    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)  #2. transform the lightweight representation of a dataset item into a format that is ready for the model to consume: including, e.g., read images, perform random data augmentation and convert to torch Tensors). 

    return build_batch_data_loader(  #The outputs of the mapper are batched (simply into a list). This batched data is the output of the data loader.
        dataset,
        sampler,
        cfg.SOLVER.IMS_PER_BATCH,
        aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
    )

3.1. Based on the parameters in the config file, class DatasetCatalog from the dataset API is used to obtain the dataset list; the return type is list[dict], a lightweight description of the dataset.

def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]

3.2. Converting the dataset-description list into a data format the model can use

a. First, wrap the list into a dataset format that torch can use

Defined in Detectron2/detectron2/data/common.py:

class DatasetFromList(data.Dataset):
    """
    Wrap a list to a torch Dataset. It produces elements of the list as data.
    """

    def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
        """
        Args:
            lst (list): a list which contains elements to produce.
            copy (bool): whether to deepcopy the element when producing it,
                so that the result can be modified in place without affecting the
                source in the list.
            serialize (bool): whether to hold memory using serialized objects, when
                enabled, data loader workers can use shared RAM from master
                process instead of making a copy.
        """
        self._lst = lst
        self._copy = copy
        self._serialize = serialize

        def _serialize(data):
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        if self._serialize:
            logger = logging.getLogger(__name__)
            logger.info(
                "Serializing {} elements to byte tensors and concatenating them all ...".format(
                    len(self._lst)
                )
            )
            self._lst = [_serialize(x) for x in self._lst]
            self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
            self._addr = np.cumsum(self._addr)
            self._lst = np.concatenate(self._lst)
            logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024 ** 2))

    def __len__(self):
        if self._serialize:
            return len(self._addr)
        else:
            return len(self._lst)

    def __getitem__(self, idx):
        if self._serialize:
            start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
            end_addr = self._addr[idx].item()
            bytes = memoryview(self._lst[start_addr:end_addr])
            return pickle.loads(bytes)
        elif self._copy:
            return copy.deepcopy(self._lst[idx])
        else:
            return self._lst[idx]

The class above inherits from the base class defined by PyTorch:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

b. Load the image and annotation data into memory, apply data augmentation, and convert to tensors

Loading the image and annotation data into memory and converting them to tensors is handled by class DatasetMapper, defined in Detectron2/detectron2/data/dataset_mapper.py:

class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )

            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict

The mapper object is then called by the MapDataset class in Detectron2/detectron2/data/common.py, where it is passed in as map_func:

class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.

    Args:
        dataset: a dataset where map function is applied.
        map_func: a callable which maps the element in dataset. map_func is
            responsible for error handling, when error happens, it needs to
            return None so the MapDataset will randomly use other
            elements from the dataset.
    """

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )

c. Batch into torch Tensors

Defined in Detectron2/detectron2/data/build.py:

def build_batch_data_loader(
    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
):
    """
    Build a batched dataloader for training.

    Args:
        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
        total_batch_size (int): total batch size across GPUs.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        iterable[list]. Length of each list is the batch size of the current
            GPU. Each element in the list comes from the dataset.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
        total_batch_size, world_size
    )

    batch_size = total_batch_size // world_size
    if aspect_ratio_grouping:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            num_workers=num_workers,
            batch_sampler=None,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
        )  # yield individual mapped dict
        return AspectRatioGroupedDataset(data_loader, batch_size)
    else:
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last=True
        )  # drop_last so the batch always have the same size
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )

DataLoader is defined in torch/utils/data/dataloader.py:

class DataLoader(object):
    def __iter__(self):  # __iter__() returns an iterator, which implements __next__() and signals completion by raising StopIteration
        return _DataLoaderIter(self)

DataLoader is iterable (it defines __iter__).

_DataLoaderIter is an iterator (it defines both __next__ and __iter__).

class _DataLoaderIter(object):
    def __next__(self):  # __next__() (next() in Python 2) returns the next batch
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self
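As a quick sketch of what the resulting loader yields (assuming a dataset registered under cfg.DATASETS.TRAIN and a config path given only as an example), each batch is a plain list of mapped dicts:

from detectron2.config import get_cfg
from detectron2.data import build_detection_train_loader

cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")  # example path
data_loader = build_detection_train_loader(cfg)

batch = next(iter(data_loader))        # list with cfg.SOLVER.IMS_PER_BATCH // num_gpus elements
sample = batch[0]
print(sample["image"].shape)           # CHW tensor produced by DatasetMapper
print(sample["instances"])             # ground-truth Instances (boxes, classes, ...)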

4. How datasets are consumed

The training class DefaultTrainer calls the data-loading interface build_detection_train_loader:

class DefaultTrainer(SimpleTrainer):
    def __init__(self, cfg):
        data_loader = self.build_train_loader(cfg)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_train_loader(cfg)

In DefaultTrainer's parent class SimpleTrainer, each training step run_step pulls data from the iterator for training:

class SimpleTrainer(TrainerBase):   #(class:`SimpleTrainer`, only does minimal SGD training and nothing else)
    def __init__(self, model, data_loader, optimizer): 
        super().__init__()
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
    def run_step(self):  # only the key steps are shown
        data = next(self._data_loader_iter)

Chapter 6: Other Supporting Functionality

1. Saving and loading models

1.1 Models

PyTorch can save either a model's parameters or the whole model, normally as a .pt or .pth file.

Method 1: save and load the model weights (recommended)

In PyTorch, the trained parameters (weights and biases) live in the parameters of the model (which, as discussed above, always inherits from class torch.nn.Module) and can be accessed through model.parameters(). Below, weights and biases are collectively called the weights. A state_dict is a Python dictionary that maps each layer to its parameter tensors, which makes saving and loading convenient.

Saving the weights:

torch.save(model.state_dict(), PATH)

Loading the weights (the model has to be constructed first, and the weights are then loaded into it):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

Note: when loading weights, the model and the weight file may not match exactly. Some layers may be missing_keys (layers that exist in the model but have no corresponding tensor in the weights), unexpected_keys (tensors in the weights that have no corresponding layer in the model), or conflicts (the size of a layer in the model differs from the size of the tensor in the weights). These cases are almost guaranteed when loading a pretrained backbone, and must be handled when loading the weights; a minimal sketch of tolerant loading follows.
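A minimal sketch of tolerant loading with the standard torch.nn.Module API (the tiny model and the file name are stand-ins for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())                   # stand-in for the real model
state_dict = torch.load("backbone_pretrained.pth", map_location="cpu")  # example file name
incompatible = model.load_state_dict(state_dict, strict=False)          # strict=False: don't raise on key mismatches
print("missing keys:", incompatible.missing_keys)                       # layers in the model with no tensor in the file
print("unexpected keys:", incompatible.unexpected_keys)                 # tensors in the file with no matching layer
# NOTE: shape conflicts still raise a RuntimeError and have to be filtered out of state_dict beforehand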

Method 2: save and load the whole model

The whole model can also be saved (the model's parameter members contain the weights). Note, however, that this does not serialize the model class itself: it stores the path of the file that contains the class and imports it again at load time, so it may break when the path changes or the class is modified.

Saving the model:

torch.save(model, PATH)

Loading the model:

model = torch.load(PATH)
model.eval()

1.2. checkpoint

To resume training from a checkpoint, the checkpoint has to carry more information. A complete checkpoint usually contains the model's state_dict, the optimizer's state_dict (a torch.optim optimizer also has a state_dict, holding its state and the hyperparameters in use, e.g. lr, momentum, weight_decay), the epoch the previous run had reached, and the last recorded loss.

Saving a checkpoint:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

Loading a checkpoint:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

That is, the checkpoint is loaded first, and then the model and the other components are restored from the loaded checkpoint.

1.3. Checkpoints in detectron2

The checkpoint support in detectron2 is built on fvcore.common.checkpoint.py from the fvcore package, whose key steps call the PyTorch save/load APIs described above; detectron2's checkpointer can override methods such as def _load_file() and def _load_model() to deal with the specific format of the weights being loaded. It is used in detectron2 in the form of hooks, which makes it possible to save a checkpoint every fixed number of iterations and to load one when training starts; a sketch of these helpers follows.
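A hedged sketch of these helpers (assuming the DetectionCheckpointer and hooks.PeriodicCheckpointer APIs; model, optimizer, scheduler, cfg and trainer are assumed to exist as in the training code of Chapter 3, and the save directory is only an example):

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import hooks

# extra checkpointables (optimizer, scheduler) are saved and restored alongside the model weights
checkpointer = DetectionCheckpointer(model, save_dir="output", optimizer=optimizer, scheduler=scheduler)

# at start-up: load cfg.MODEL.WEIGHTS, or resume from the last checkpoint found in save_dir
checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=True)

# during training it is wrapped in a hook that saves every cfg.SOLVER.CHECKPOINT_PERIOD iterations
trainer.register_hooks([hooks.PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)])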

2. Logging

Logging uses the EventStorage class to buffer information; an EventWriter reads the information from EventStorage and writes it to the log, and the whole process is driven by PeriodicWriter. Below we work backwards from the point of use to see the complete process.

2.1 Driving the logging flow

Logging is also controlled through hooks, using class PeriodicWriter, defined in Detectron2\detectron2\engine\hooks.py; its def after_step() method calls writer.write() at a fixed period.

class PeriodicWriter(HookBase):
    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

2.2 Reading the buffered information and writing the log

The writers passed to PeriodicWriter all inherit from EventWriter and define exactly what is done when the log is written.

class EventWriter:
    """
    Base class for writers that obtain events from :class:`EventStorage` and process them.
    """

    def write(self):
        raise NotImplementedError

    def close(self):
        pass

Take class CommonMetricPrinter(EventWriter) as an example:

class CommonMetricPrinter(EventWriter):
    """
    Print **common** metrics to the terminal, including
    iteration time, ETA, memory, all losses, and the learning rate.
    It also applies smoothing using a window of 20 elements.

    It's meant to print common metrics in common ways.
    To print something in more customized ways, please implement a similar printer by yourself.
    """

    def __init__(self, max_iter):
        """
        Args:
            max_iter (int): the maximum number of iterations to train.
                Used to compute ETA.
        """
        self.logger = logging.getLogger(__name__)
        self._max_iter = max_iter
        self._last_write = None

    def write(self):
        storage = get_event_storage()
        iteration = storage.iter

        try:
            data_time = storage.history("data_time").avg(20)
        except KeyError:
            # they may not exist in the first few iterations (due to warmup)
            # or when SimpleTrainer is not used
            data_time = None

        eta_string = None
        try:
            iter_time = storage.history("time").global_avg()
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            iter_time = None
            # estimate eta on our own - more noisy
            if self._last_write is not None:
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())

        try:
            lr = "{:.5g}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        if torch.cuda.is_available():
            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
        else:
            max_mem_mb = None

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            " {eta}iter: {iter}  {losses}  {time}{data_time}lr: {lr}  {memory}".format(
                eta=f"eta: {eta_string}  " if eta_string else "",
                iter=iteration,
                losses="  ".join(
                    [
                        "{}: {:.4g}".format(k, v.median(20))
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]
                ),
                time="time: {:.4f}  ".format(iter_time) if iter_time is not None else "",
                data_time="data_time: {:.4f}  ".format(data_time) if data_time is not None else "",
                lr=lr,
                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
            )
        )

storage = get_event_storage() retrieves the information to be logged; this function returns the most recent entry from the global _CURRENT_STORAGE_STACK = []. Inside a storage, the information is kept in a history dictionary; for example, reading data_time above requires calling storage.history():

data_time = storage.history("data_time").avg(20)

2.3 Buffering the information

_CURRENT_STORAGE_STACK is populated by the EventStorage class; for a given storage, writing and reading its history are also defined in EventStorage.

class EventStorage:
    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(HistoryBuffer)
    def put_scalar(self, name, value, smoothing_hint=True):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.

        Args:
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`.  A writer may ignore the hint
                and apply custom smoothing rule.

                It defaults to True because most scalars we save need to be smoothed to
                provide any useful signal.
        """
        name = self._current_prefix + name
        history = self._history[name]
        value = float(value)
        history.update(value, self._iter)
        self._latest_scalars[name] = (value, self._iter)

        existing_hint = self._smoothing_hints.get(name)
        if existing_hint is not None:
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    def history(self, name):
        """
        Returns:
            HistoryBuffer: the scalar history for name
        """
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def __enter__(self):
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

EventStorage uses context management, i.e. it implements the __enter__ and __exit__ methods: __enter__ is called when the with statement starts running, and __exit__ is called when the with block finishes.

Recall the creation of the EventStorage context manager mentioned in train:

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

During training, run_step() calls storage.put_scalar():

self.storage.put_scalar("data_time", data_time)

Reference: https://pytorch.org/tutorials/beginner/saving_loading_models.html
