MMDetection Framework Getting Started Tutorial (Full Version)

  There are many MMDetection tutorials online, but they tend not to be systematic. After reading them, I still did not know how to use MMDetection. I therefore recommend following the official tutorials directly and learning MMDetection alongside its source code. The relevant links are summarized below:

  1. Official Tutorial - MMCV
  2. Official Tutorial - MMDetection
  3. Official Tutorial - MMDetection learning route that you must know (personal experience version)
  4. Xi'an Jiaotong University courseware - mmdetection tutorial (use)

  This article explains how to build your own algorithm from scratch in MMDetection. My earlier posts were notes taken while learning, in which I analyzed the principles of MMDetection in detail from the source code itself. This post reorganizes how MMDetection is used and how its processes work from a higher-level perspective. It can be regarded as a summary of the past month of learning.

  1. MMDetection Framework Getting Started Tutorial (1): Installation under Anaconda3 (mmdet + mmdet3d)
  2. MMDetection Framework Getting Started Tutorial (2): Quick Start Tutorial
  3. MMDetection Framework Getting Started Tutorial (3): Detailed Analysis of Configuration Files
  4. MMDetection Framework Getting Started Tutorial (4): Detailed Explanation of the Registration Mechanism
  5. MMDetection Framework Getting Started Tutorial (5): Detailed Analysis of Runner and Hook

1. Framework overview

  MMDetection is an open-source object detection project launched by SenseTime and the Chinese University of Hong Kong. It implements a large number of object detection algorithms on top of PyTorch. It also encapsulates dataset construction, model construction, and training strategies into modules. Because these modules are combined by calling them, a new algorithm can be implemented with a small amount of code, which greatly improves code reuse.

  Besides MMDetection, the MMLab family includes open-source projects such as MMTracking for object tracking and MMDetection3D for 3D object detection. All of them are built on PyTorch and MMCV. PyTorch needs no introduction. MMCV is a foundational library for computer vision whose main role is to provide a general PyTorch-based training framework. For example, the Registry, Runner, and Hook mechanisms we often mention are all provided by MMCV. MMCV also provides general-purpose I/O interfaces, a variety of CNN architectures, and high-quality implementations of common CUDA operators, which will not be covered further here.

2. The overall process of the framework

2.1 PyTorch

  When we use PyTorch to build a new algorithm, the process usually includes the following steps:

  • Build a dataset: create a new class that inherits from Dataset. Override __getitem__() to load and index the data and labels, and define the data preprocessing as a pipeline
  • Build a data loader: instantiate a DataLoader with the appropriate parameters
  • Build a model: create a new class that inherits from Module and override forward() to define the model's forward pass
  • Define the loss function and optimizer: choose a suitable loss function and optimizer for the algorithm
  • Train and validate: repeatedly fetch data and labels from the DataLoader and feed them to the model. Compute the loss, then let the optimizer update the parameters using the gradients from backpropagation
  • Other operations: training tricks, log printing, checkpoint saving, and similar operations can be interspersed anywhere in the main training script

2.2 MMDetection

  When using MMDetection to build a new algorithm, the process usually includes the following steps:

  • Register the dataset: CustomDataset is MMDetection's re-encapsulation of the original Dataset class. Its __getitem__() method dispatches to prepare_train_img() or prepare_test_img(), depending on whether it is in training or test mode. When you build your own dataset by inheriting from CustomDataset, override load_annotations() and get_ann_info() to define how data and labels are loaded and indexed. After defining the dataset class, you also need to register it with DATASETS.register_module().
  • Register the model: models are built much as in PyTorch, by creating a subclass and overriding forward(). The only difference is that MMDetection models inherit from BaseModule instead of Module. BaseModule is a subclass of Module, and models in the MMLab projects are expected to inherit from it. In addition, MMDetection splits a complete model into a backbone, a neck, and a head. Users therefore need to decompose their model into these three classes and register them with BACKBONES.register_module(), NECKS.register_module(), and HEADS.register_module() respectively.
  • Build a configuration file: the configuration file sets the runtime parameters of each component of the algorithm. It generally consists of four parts: datasets, models, schedules, and runtime. After defining and registering the corresponding modules, you set their parameters in the configuration file. MMDetection then reads and parses the file and uses the Registry class to instantiate the modules. Configuration files can also inherit from one another through the _base_ field, which improves reuse.
  • Train and validate: once each module is implemented and registered and the configuration file is written, the model can be trained and validated with ./tools/train.py and ./tools/test.py, without any additional code.

2.3 Process comparison

  Although the steps for implementing an algorithm in MMDetection look quite different from those in PyTorch, the underlying logic is essentially the same. You can refer to the following figure for a comparison. The blue part represents the PyTorch process, the orange part represents the MMDetection process, and the green part represents general steps that are independent of the framework.

  Before looking at how algorithms are implemented in MMDetection, you should have a general understanding of the registration mechanism and the Hook mechanism. I recommend skimming Chapters 3 and 4 first to get the big picture. After reading Chapter 5, coming back to the details of these two mechanisms will give you a deeper understanding.

3. Registration mechanism

3.1 Registry class

  As a downstream project of MMCV, MMDetection inherits MMCV's approach to module management: the registration mechanism. Put simply, the registration mechanism maintains several lookup tables. In each table, the key is a module's name and the value is a handle to the module, and each table manages a group of modules with similar functions. Whenever we create a new module, we must store the corresponding key-value pair in the lookup table that matches the module's function. This process is called "registration". When we want to use a module, we only need to look up its handle by name in the table, after which we can initialize it or call its methods. MMCV implements this mapping from strings (keys) to classes (values) through the Registry class.

  The constructor of Registry is shown below. The variable self._module_dict is the "lookup table" mentioned above: registered modules are stored in this dictionary. Creating a Registry instance therefore means creating a new lookup table. Registry also supports inheritance through the parent argument.

# Simplified excerpt of mmcv.utils.Registry (MMCV 1.x); in practice: from mmcv.utils import Registry

class Registry:
    # Constructor
    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Name of the registry
        self._name = name
        # module_dict stores the mapping from strings to classes
        self._module_dict = dict()
        # children stores the child registries
        self._children = dict()
        # (handling of the scope argument is omitted here)

        # build_func is initialized with the following priority:
        # 1. build_func: use the explicitly given function first
        # 2. parent.build_func: otherwise inherit the parent's build_func
        # 3. build_from_cfg: by default, instantiate objects from a config dict
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        # Set up the parent-child relationship
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

  Modules are registered through the Registry member function register_module(), which internally calls the private function _register_module(). The core logic lives in _register_module() and is very simple: it stores the incoming module_name and module_class in the dictionary self._module_dict.

def _register_module(self, module_class, module_name=None, force=False):
    # If no name is given, use the class name as the default
    if module_name is None:
        module_name = module_class.__name__

    # module_name may also be a list (e.g., to support building PyTorch modules in nn.Sequential)
    if isinstance(module_name, str):
        module_name = [module_name]

    for name in module_name:
        # force=False: registering an existing name raises an error
        # force=True: the later registration overwrites the earlier one
        if not force and name in self._module_dict:
            raise KeyError(f'{name} is already registered in {self.name}')
        # Add the module to the lookup table
        self._module_dict[name] = module_class

  Once we have a module's handle from its name string, we can instantiate it through the function handle self.build_func. build_func can be specified manually or inherited from the parent registry. By default it is build_from_cfg(), which initializes the module from the configuration parameters cfg. cfg is a dictionary whose type field holds the module name as a string, while the other fields correspond to the arguments of the module's constructor.

import inspect


def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    # Merge externally passed default_args into args (values already in cfg take precedence)
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Get the module name
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # get() returns the class registered under obj_type in registry._module_dict
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        # type is the class itself
        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

    # Instantiate the module and return the instance
    try:
        return obj_cls(**args)
    except Exception as e:
        raise type(e)(f'{obj_cls.__name__}: {e}')

  Because the registry argument must point to the current registry itself, we generally call the Registry's build() method rather than self.build_func directly.

def build(self, *args, **kwargs):
    return self.build_func(*args, **kwargs, registry=self)

  The following small example simulates registering and calling network models. Note that printing a Registry object actually prints the contents of its self._module_dict.

from mmcv.utils import Registry

# Create a registry to manage models
MODELS = Registry('myModels')

# Option 1: register with a decorator when the class is defined (recommended)
@MODELS.register_module()
class ResNet(object):
    def __init__(self, depth):
        self.depth = depth
        print('Initialize ResNet{}'.format(depth))

# Option 2: call register_module explicitly after the class is defined (not recommended)
class FPN(object):
    def __init__(self, in_channel):
        self.in_channel = in_channel
        print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)

print(MODELS)
""" Output:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""

# Config parameters; cfg usually comes from a configuration file
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# Build the models (the config parameters are passed to the constructors) to get instances
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" Output:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""

3.2 Summary of Registration Mechanism

  The registration mechanism is a way of managing modules. Modules are grouped by function, and each group is maintained by a lookup table that records the mapping from module name (a string) to the module itself (a class). Recording such a mapping in the lookup table is called "registration". Once a module is registered, its handle can easily be looked up by name, and the module can then be initialized and used in the normal program flow. Registering and using a module takes 5 steps:

  1. Create a new class that implements the custom functionality
  2. Register this class in the corresponding lookup table (register_module)
  3. Specify the initialization parameters of the module in the configuration file
  4. Instantiate the module through the build function (build_from_cfg)
  5. Use the instance to perform its function
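  The five steps above can be traced end to end in plain Python. The sketch below uses a hypothetical MiniRegistry class, not MMCV's Registry. It keeps only the lookup table, decorator-based registration with the force check, and a build() that follows the build_from_cfg() logic shown earlier. Parent registries and scopes are left out.

```python
import inspect


class MiniRegistry:
    """Hypothetical, simplified registry: a name -> class lookup table."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None, force=False):
        """Return a decorator that stores the class in the lookup table."""
        def _register(cls):
            key = name or cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)

    def build(self, cfg, default_args=None):
        """Instantiate a module from a config dict, following build_from_cfg()."""
        args = dict(cfg)                          # copy: the caller's cfg is untouched
        for k, v in (default_args or {}).items():
            args.setdefault(k, v)                 # values already in cfg win
        obj_type = args.pop('type')
        if inspect.isclass(obj_type):
            cls = obj_type                        # 'type' may be the class itself
        else:
            cls = self.get(obj_type)
            if cls is None:
                raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return cls(**args)


# Steps 1-2: create a class and register it in the lookup table
MODELS = MiniRegistry('models')


@MODELS.register_module()
class ResNet:
    def __init__(self, depth, norm='BN'):
        self.depth = depth
        self.norm = norm


# Step 3: the parameters that would normally come from a config file
backbone_cfg = dict(type='ResNet', depth=50)

# Steps 4-5: build the instance and use it
backbone = MODELS.build(backbone_cfg, default_args=dict(norm='GN'))
print(backbone.depth, backbone.norm)  # 50 GN
```

  Because default_args are merged with setdefault(), a value written in the config always wins over a default supplied by the caller. This is the same precedence rule as in build_from_cfg().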

4. Hook mechanism

4.1 Hook class

  The whole algorithm pipeline of MMDetection works like a black box: given an input (the configuration file), it outputs the algorithm's results. The process is highly encapsulated and requires almost no handwritten code. How, then, can we add custom operations during execution? That is what the Hook mechanism is for.

  Simply put, a Hook can be understood as a trigger that executes predefined functions at predefined locations in the program. Following the life cycle of the algorithm, MMCV predefines 6 sites where user-defined functions can be inserted. Users can insert any number of operations at each site, as shown in the following figure:

  These 6 sites cover essentially every place where a custom operation might be needed. MMCV already implements a number of commonly used Hooks. These default Hooks do not need to be registered by the user; their parameters can be set in the configuration file. Custom Hooks, by contrast, must be registered manually through the custom_hooks field in the configuration file.

  The Hook class itself contains very little code: it only provides interface functions for the predefined locations. Any custom Hook must inherit from Hook and override the interface functions it needs. For example, checkpoint saving usually happens after an iteration or an epoch, so CheckpointHook overrides after_train_iter and after_train_epoch.

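# Simplified excerpt of mmcv.runner.Hook (MMCV 1.x)
class Hook:
    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    # Training-specific callbacks fall back to the generic ones by default;
    # before_train_epoch(), before_train_iter(), etc. follow the same pattern
    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)


# HOOKS is MMCV's Hook registry (mmcv.runner.HOOKS)
@HOOKS.register_module()
class CheckpointHook(Hook):
    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 **kwargs):
        ...

    def after_train_iter(self, runner):
        ...

    def after_train_epoch(self, runner):
        ...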

  Unlike other modules, a Hook that has been defined (and registered in the HOOKS registry) must also be registered in the Runner before it can be used, so there are two registrations. The first, into HOOKS, lets the program find the module by its Hook name. The second, into the Runner, makes the program call the corresponding functions when execution reaches the predefined locations.

  Runner is the class MMCV uses to manage the training process. It internally maintains a list variable, self._hooks. All Hook instances that will be called during training must be added to this list in priority order, which is done with Runner.register_hook(). MMCV predefines several priority levels; the smaller the number, the higher the priority. If the predefined levels are too coarse, you can also pass an integer from 0 to 100 directly for finer control.

# Excerpt of Runner.register_hook() (MMCV 1.x); get_priority() is defined in mmcv.runner.priority
def register_hook(self, hook, priority='NORMAL'):
    """Predefined priorities
    +--------------+------------+
    | Level        | Value      |
    +==============+============+
    | HIGHEST      | 0          |
    +--------------+------------+
    | VERY_HIGH    | 10         |
    +--------------+------------+
    | HIGH         | 30         |
    +--------------+------------+
    | ABOVE_NORMAL | 40         |
    +--------------+------------+
    | NORMAL       | 50         |
    +--------------+------------+
    | BELOW_NORMAL | 60         |
    +--------------+------------+
    | LOW          | 70         |
    +--------------+------------+
    | VERY_LOW     | 90         |
    +--------------+------------+
    | LOWEST       | 100        |
    +--------------+------------+
    """
    # Convert a level name such as 'NORMAL' (or an int in [0, 100]) to an int
    priority = get_priority(priority)
    hook.priority = priority
    # Insertion sort: keep self._hooks in ascending order of priority value
    inserted = False
    for i in range(len(self._hooks) - 1, -1, -1):
        if priority >= self._hooks[i].priority:
            self._hooks.insert(i + 1, hook)
            inserted = True
            break
    if not inserted:
        self._hooks.insert(0, hook)
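  To see what the insertion loop does with ties, here is a standalone sketch. The register_hook function and DummyHook class are hypothetical stand-ins that operate on a plain list. The PRIORITY table copies the values from the docstring above instead of calling MMCV's get_priority().

```python
# Values copied from the priority table above (hypothetical stand-in for get_priority)
PRIORITY = {'HIGHEST': 0, 'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40,
            'NORMAL': 50, 'BELOW_NORMAL': 60, 'LOW': 70, 'VERY_LOW': 90,
            'LOWEST': 100}


class DummyHook:
    def __init__(self, name):
        self.name = name


def register_hook(hooks, hook, priority='NORMAL'):
    # Accept a level name or an integer in [0, 100]
    priority = PRIORITY[priority] if isinstance(priority, str) else int(priority)
    hook.priority = priority
    # Scan from the back and insert after the last hook whose value <= ours
    for i in range(len(hooks) - 1, -1, -1):
        if priority >= hooks[i].priority:
            hooks.insert(i + 1, hook)
            return
    hooks.insert(0, hook)


hooks = []
register_hook(hooks, DummyHook('logger'), 'VERY_LOW')
register_hook(hooks, DummyHook('checkpoint'), 'NORMAL')
register_hook(hooks, DummyHook('lr'), 'VERY_HIGH')
register_hook(hooks, DummyHook('eval'), 'NORMAL')
register_hook(hooks, DummyHook('custom'), 45)
print([h.name for h in hooks])  # ['lr', 'custom', 'checkpoint', 'eval', 'logger']
```

  Because the scan runs from the back and inserts after the first hook whose priority value is less than or equal to the new one, hooks of equal priority keep their registration order. In this example, checkpoint still runs before eval.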

  Once Hook instances have been added to self._hooks, call_hook() can be invoked at each predefined location to call the corresponding method of every Hook instance. call_hook() is what is referred to as the callback function.

# Simplified from EpochBasedRunner (MMCV 1.x)
# Called when the run starts
self.call_hook('before_run')

while self.epoch < self._max_epochs:

    # Called before each epoch
    self.call_hook('before_train_epoch')

    for i, data_batch in enumerate(self.data_loader):
        # Called before each iteration
        self.call_hook('before_train_iter')

        self.outputs = self.model.train_step(data_batch, self.optimizer)

        # Called after each iteration
        self.call_hook('after_train_iter')
        self._iter += 1

    # Called after each epoch
    self.call_hook('after_train_epoch')
    self._epoch += 1

# Called when the run finishes
self.call_hook('after_run')

  When called, call_hook() traverses all Hook instances in self._hooks and calls the member function named by fn_name on each of them. For example, with fn_name='before_train_epoch', call_hook() calls the before_train_epoch() method of every Hook in turn. Since self._hooks is already sorted by priority, higher-priority Hooks are called first.

def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)

4.2 Summary of Hook Mechanism

  A Hook is a trigger set at a fixed position in the program. When execution reaches that position, control passes to the Hook functions and then returns to the same position to continue the main flow. Implementing a Hook takes 5 steps:

  1. Define a class that inherits from the Hook base class
  2. Override only those functions of the Hook base class that the custom Hook needs
  3. Register the custom Hook module in the HOOKS lookup table (register_module)
  4. Instantiate the Hook module and register it in the Runner (register_hook)
  5. Call the overridden Hook functions through the callback function (call_hook)
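  The five steps can be exercised without MMCV using the toy classes below. ToyRunner, LossLoggerHook, and CheckpointLikeHook are hypothetical. In this sketch the base Hook exposes the train-specific callbacks directly, and priority sorting is left out. The loop follows the order of call_hook() calls shown earlier, including incrementing the iteration and epoch counters after the after_* hooks.

```python
class Hook:
    """Base class: every callback is a no-op unless a subclass overrides it."""

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_train_epoch(self, runner):
        pass

    def after_train_epoch(self, runner):
        pass

    def before_train_iter(self, runner):
        pass

    def after_train_iter(self, runner):
        pass


class ToyRunner:
    """Hypothetical epoch-based runner that keeps only the hook plumbing."""

    def __init__(self, data, max_epochs):
        self.data = data
        self.max_epochs = max_epochs
        self.epoch = 0
        self.iter = 0
        self.outputs = None
        self.log = []
        self._hooks = []

    def register_hook(self, hook):
        self._hooks.append(hook)  # priority sorting omitted for brevity

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def run(self):
        self.call_hook('before_run')
        while self.epoch < self.max_epochs:
            self.call_hook('before_train_epoch')
            for batch in self.data:
                self.call_hook('before_train_iter')
                self.outputs = sum(batch)  # stand-in for model.train_step()
                self.call_hook('after_train_iter')
                self.iter += 1
            self.call_hook('after_train_epoch')
            self.epoch += 1
        self.call_hook('after_run')


class LossLoggerHook(Hook):
    """Hypothetical hook: records the 'loss' after every iteration."""

    def after_train_iter(self, runner):
        runner.log.append(f'iter {runner.iter}: loss={runner.outputs}')


class CheckpointLikeHook(Hook):
    """Hypothetical hook: 'saves' every `interval` epochs by logging a message."""

    def __init__(self, interval=1):
        self.interval = interval

    def after_train_epoch(self, runner):
        if (runner.epoch + 1) % self.interval == 0:
            runner.log.append(f'save epoch_{runner.epoch + 1}')


runner = ToyRunner(data=[[1, 2], [3, 4]], max_epochs=2)
runner.register_hook(LossLoggerHook())
runner.register_hook(CheckpointLikeHook(interval=2))
runner.run()
print(runner.log)
```

  The runner never refers to a specific Hook class. It only calls method names through call_hook(), which is why new behavior can be plugged in without touching the training loop.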

5. Algorithm implementation process

  Section 2.2 mentioned that implementing a new algorithm with MMDetection involves four steps: registering the dataset, registering the model, building the configuration file, and training/validation. To understand how algorithms are implemented in MMDetection, you need a thorough understanding of four classes: Config, Registry, Runner, and Hook.

5.1 Registering the Dataset

  To define your own dataset, write a new Dataset class that inherits from CustomDataset, then override load_annotations() and get_ann_info(). The official documentation says that to use CustomDataset, users need to convert their existing dataset into an MMDetection-compatible format (COCO format or the intermediate format). However, after looking at the underlying code, I found no such limitation. Any data format works, as long as it matches your own implementations of load_annotations() and get_ann_info().

"""
中间数据格式:
[
    {
        'filename': 'a.jpg',										# 图片路径
        'width': 1280,												# 图片尺寸
        'height': 720,
        'ann': {													# 标注信息
            'bboxes': <np.ndarray, float32> (n, 4),					# 标注框坐标(x1, y1, x2, y2)
            'labels': <np.ndarray, int64> (n, ),					# 标注框类别
            'bboxes_ignore': <np.ndarray, float32> (k, 4),			# 不关注的标注框坐标(可选)
            'labels_ignore': <np.ndarray, int64> (k, ) 				# 不关注的标注框类别(可选)
        }
    },
    ...
]
"""

from torch.utils.data import Dataset
from mmdet.datasets.pipelines import Compose

# Simplified excerpt of mmdet.datasets.CustomDataset (MMDetection 2.x)
class CustomDataset(Dataset):
    CLASSES = None
    def __init__(self,
                 ann_file,              # annotation file path
                 pipeline,              # data preprocessing pipeline
                 classes=None,          # class names to detect
                 data_root=None,        # dataset root path
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,       # if True, annotations are not loaded
                 filter_empty_gt=True): # if True, images without GT boxes are filtered out (only when test_mode=False)
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)

        # Load samples and labels via load_annotations()
        self.data_infos = self.load_annotations(self.ann_file)

        # Users can override _filter_imgs() to customize sample filtering during training
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]

        # Build the preprocessing pipeline applied to every sample
        self.pipeline = Compose(pipeline)

  In PyTorch, indexing a Dataset is implemented by overriding __getitem__(). Although MMDetection's CustomDataset is a subclass of Dataset, it does not require us to override __getitem__(). To make it easy to manage data in both training and test modes, MMDetection has already overridden __getitem__() so that it calls prepare_train_img() or prepare_test_img() according to the current mode. The two differ in whether training labels are loaded. We therefore only need to override load_annotations() and get_ann_info() and leave the rest to MMDetection.

def __getitem__(self, idx):
    if self.test_mode:
        return self.prepare_test_img(idx)
    else:
        return self.prepare_train_img(idx)

# Return a preprocessed training sample together with its labels
def prepare_train_img(self, idx):
    img_info = self.data_infos[idx]
    # Get the training labels via get_ann_info()
    ann_info = self.get_ann_info(idx)
    results = dict(img_info=img_info, ann_info=ann_info)
    return self.pipeline(results)

# Return a preprocessed test sample
def prepare_test_img(self, idx):
    img_info = self.data_infos[idx]
    results = dict(img_info=img_info)
    return self.pipeline(results)

  After defining the custom Dataset class, don't forget to decorate it with @DATASETS.register_module() to add it to the DATASETS table.
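  The division of labor between CustomDataset and a subclass can be sketched without PyTorch or MMDetection. MiniCustomDataset and ToyTxtDataset below are hypothetical. The base class copies only the __getitem__() dispatch shown above, and the pipeline is any callable that takes and returns a dict, standing in for Compose. To keep the sketch self-contained, ann_file is passed as a list of text lines instead of a file path, and boxes are plain lists rather than NumPy arrays.

```python
class MiniCustomDataset:
    """Hypothetical stand-in for CustomDataset (no torch, no mmdet)."""

    def __init__(self, ann_file, pipeline, test_mode=False):
        self.test_mode = test_mode
        self.pipeline = pipeline  # any callable dict -> dict, in place of Compose
        self.data_infos = self.load_annotations(ann_file)

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        # Same dispatch as CustomDataset.__getitem__()
        if self.test_mode:
            return self.prepare_test_img(idx)
        return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        results = dict(img_info=self.data_infos[idx],
                       ann_info=self.get_ann_info(idx))
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        return self.pipeline(dict(img_info=self.data_infos[idx]))

    # The two methods a subclass is expected to override
    def load_annotations(self, ann_file):
        raise NotImplementedError

    def get_ann_info(self, idx):
        raise NotImplementedError


class ToyTxtDataset(MiniCustomDataset):
    """Hypothetical format: 'filename width height x1 y1 x2 y2 label' per line."""

    def load_annotations(self, ann_file):
        infos = []
        for line in ann_file:
            name, w, h, x1, y1, x2, y2, label = line.split()
            infos.append(dict(
                filename=name, width=int(w), height=int(h),
                ann=dict(bboxes=[[float(x1), float(y1), float(x2), float(y2)]],
                         labels=[int(label)])))
        return infos

    def get_ann_info(self, idx):
        return self.data_infos[idx]['ann']


def identity(results):
    return results  # placeholder for a real preprocessing pipeline


lines = ['a.jpg 1280 720 10 20 110 220 3', 'b.jpg 640 480 0 0 50 50 1']
train_set = ToyTxtDataset(lines, pipeline=identity)
test_set = ToyTxtDataset(lines, pipeline=identity, test_mode=True)
print(sorted(train_set[0]))  # ['ann_info', 'img_info']
print(sorted(test_set[0]))   # ['img_info']
```

  The subclass only describes how to read its own format. Switching between training and test behavior is handled entirely by the inherited __getitem__().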

5.2 Registering the model

  Defining the network model is relatively simple. There are only two differences compared with PyTorch:

  1. The parent class changes from Module to BaseModule
  2. The model must be split into three parts (backbone, neck, and head), which are defined separately and registered in BACKBONES, NECKS, and HEADS, respectively.
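  The backbone/neck/head split is mostly a matter of wiring. The detector's constructor receives nested config dicts and builds each part from its own registry, roughly as MMDetection 2.x detectors do with helpers such as build_backbone(). The sketch below is hypothetical throughout: it uses a tiny Registry, Toy* classes, integers in place of tensors, and no BaseModule, and it only shows that wiring.

```python
class Registry:
    """Tiny hypothetical registry: class name -> class."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        args = dict(cfg)
        return self._module_dict[args.pop('type')](**args)


BACKBONES = Registry('backbone')
NECKS = Registry('neck')
HEADS = Registry('head')
DETECTORS = Registry('detector')


@BACKBONES.register_module()
class ToyBackbone:
    def __init__(self, depth):
        self.depth = depth

    def forward(self, x):
        return [x * self.depth]  # pretend single-level feature list


@NECKS.register_module()
class ToyNeck:
    def __init__(self, out_channels):
        self.out_channels = out_channels

    def forward(self, feats):
        return [f + self.out_channels for f in feats]


@HEADS.register_module()
class ToyHead:
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def forward(self, feats):
        return [f % self.num_classes for f in feats]


@DETECTORS.register_module()
class ToyDetector:
    """Only wires the parts together; each part is built from its own dict."""

    def __init__(self, backbone, neck, head):
        self.backbone = BACKBONES.build(backbone)
        self.neck = NECKS.build(neck)
        self.head = HEADS.build(head)

    def forward(self, x):
        return self.head.forward(self.neck.forward(self.backbone.forward(x)))


model_cfg = dict(type='ToyDetector',
                 backbone=dict(type='ToyBackbone', depth=3),
                 neck=dict(type='ToyNeck', out_channels=4),
                 head=dict(type='ToyHead', num_classes=7))
model = DETECTORS.build(model_cfg)
print(model.forward(2))  # [3]
```

  Swapping the backbone then only means registering another class and changing the type string in the nested dict. The detector code stays the same.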

5.3 Build configuration files

  As mentioned in Section 2.2, under the MMDetection framework we don't need to write any code for the iterative training/testing process. We only need to run the ready-made train.py or test.py. But how does MMDetection know which modules we need? That is the job of the configuration file.

5.3.1 Configuration file composition

  A configuration file is a text file made up of a series of variable definitions. Variables of type dict represent modules. Each such dict must contain a type field holding the module name, and its other fields correspond to the parameters of the module's constructor, which are used to initialize the module (see build_from_cfg() in Chapter 3 of this article). The module must be registered; otherwise MMDetection cannot find it from the type value. Besides dict variables, a configuration file can contain variables of any other type. These are usually intermediate helper variables used to define the dicts, for example:

test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=True),
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='PadMultiViewImage', size_divisor=32)
]
evaluation = dict(interval=2, pipeline=test_pipeline)
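  Lists of dicts such as test_pipeline are turned into callables by a Compose class (the CustomDataset constructor wraps its pipeline with Compose). The sketch below assumes a plain dict as the transform registry and two hypothetical transforms, ToyResize and ToyPad, that only update shape entries. Like MMDetection's Compose, it stops and returns None if a transform drops the sample.

```python
PIPELINES = {}  # hypothetical transform registry: name -> class


def register(cls):
    PIPELINES[cls.__name__] = cls
    return cls


@register
class ToyResize:
    """Scales results['img_shape'] by a constant factor."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, results):
        h, w = results['img_shape']
        results['img_shape'] = (round(h * self.scale), round(w * self.scale))
        return results


@register
class ToyPad:
    """Records the shape after padding up to a multiple of size_divisor."""

    def __init__(self, size_divisor):
        self.size_divisor = size_divisor

    def __call__(self, results):
        d = self.size_divisor
        h, w = results['img_shape']
        results['pad_shape'] = ((h + d - 1) // d * d, (w + d - 1) // d * d)
        return results


class Compose:
    """Builds each transform from its dict, then applies them in order."""

    def __init__(self, transforms):
        self.transforms = []
        for t in transforms:
            if isinstance(t, dict):
                args = dict(t)
                t = PIPELINES[args.pop('type')](**args)
            self.transforms.append(t)

    def __call__(self, results):
        for t in self.transforms:
            results = t(results)
            if results is None:  # a transform may drop the sample
                return None
        return results


test_pipeline = [
    dict(type='ToyResize', scale=0.5),
    dict(type='ToyPad', size_divisor=32),
]
pipeline = Compose(test_pipeline)
out = pipeline(dict(img_shape=(720, 1280)))
print(out)  # {'img_shape': (360, 640), 'pad_shape': (384, 640)}
```

  Each transform reads and writes the same results dict, so the order of the list in the configuration file is the order in which the operations run.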

  Configuration files also support inheritance through the _base_ variable, a list that holds the paths of the configuration files to inherit. When parsing a file, the parser recursively parses all of these files, which may themselves contain _base_ variables. Traced back, a configuration file typically inherits from the following four base files, which cover datasets, models, training schedules, and the default runtime configuration (default_runtime):

_base_ = [
    'mmdetection/configs/_base_/models/fast_rcnn_r50_fpn.py',		# models
    'mmdetection/configs/_base_/datasets/coco_detection.py',		# datasets
    'mmdetection/configs/_base_/schedules/schedule_1x.py',			# schedules
    'mmdetection/configs/_base_/default_runtime.py',		# default_runtime
]

  If you print a configuration file that inherits the four base files above, you will see the following content. This is the configuration information that any complete configuration file should contain, and you can add any custom configuration on top of it. When creating a new configuration file, we therefore usually inherit these 4 base files and then adjust them for the task at hand.

# 1. Model config (models) =========================================
model = dict(
    type='FastRCNN',                # the model is FastRCNN
    backbone=dict(                  # the backbone is ResNet
        type='ResNet',
        ...,
    ),
    neck=dict(                      # the neck is FPN
        type='FPN',
        ...,
    ),
    roi_head=dict(                  # the head is StandardRoIHead
        type='StandardRoIHead',
        ...,
        bbox_head=dict(
            ...,
            loss_cls=dict(...),     # classification loss
            loss_bbox=dict(...),    # regression loss
        ),
    ),
    train_cfg=dict(                 # training settings
        rcnn=dict(
            assigner=dict(...),     # BBox assigner
            sampler=dict(...),      # BBox sampler
            ...,
        ),
    ),
    test_cfg=dict(                  # test settings
        rcnn=dict(
            nms=dict(...),          # NMS post-processing
            ...,
        ),
    ),
)

# 2. Dataset config (datasets) =========================================
dataset_type = '...'            # dataset class name
data_root = '...'               # dataset root directory
img_norm_cfg = dict(...)        # image normalization parameters
train_pipeline = [              # training data pipeline
    ...,
]
test_pipeline = [...]           # test data pipeline
data = dict(
    samples_per_gpu=2,          # batch size per GPU
    workers_per_gpu=2,          # data-loading workers per GPU
    train=dict(                 # training set
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',  # annotation file
        img_prefix=data_root + 'train2017/',   # image prefix
        pipeline=train_pipeline,               # data preprocessing pipeline
    ),
    val=dict(                   # validation set
        ...,
        pipeline=test_pipeline,
    ),
    test=dict(                  # test set
        ...,
        pipeline=test_pipeline,
    )
)

# 3. Training schedule config (schedules) =========================================
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)

# 4. Runtime config (runtime) =========================================
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

  There are also optional configuration parameters such as custom_imports, which imports user-defined modules. When the parser reaches this field, it calls import_modules_from_strings() to import the modules listed in the field's imports entry.

custom_imports = dict(imports=['os.path', 'numpy'],   # list of module names to import
                      allow_failed_imports=False)     # if True, a failed import yields None (with a warning) instead of an error

5.3.2 Modification of configuration files

  There are two situations when modifying the configuration file:

  1. Modify a parameter of an existing dict: simply override that parameter
  2. Delete all parameters of the original dict and replace them with a new set: add a _delete_=True field

  Take modifying the learning rate and replacing the optimizer as examples to explain how to modify the configuration file in these two cases:

# Original optimizer inherited from _base_
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

# Change the learning rate
optimizer = dict(lr=0.001)
# Resulting optimizer after merging:
# optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)

# Replace SGD with AdamW
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)
# Resulting optimizer after merging:
# optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
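  The two modification rules can be expressed as a small recursive merge. merge_cfg below is a hypothetical simplification; MMCV does this in Config._merge_a_into_b, with extra checks. Nested dicts are merged key by key, and a dict carrying _delete_=True replaces the inherited dict instead of being merged into it.

```python
import copy


def merge_cfg(child, base):
    """Hypothetical sketch: return a new dict with child merged into base.

    - Keys in child override the same keys in base.
    - Nested dicts are merged key by key.
    - A child dict containing _delete_=True replaces the base dict entirely.
    """
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)                      # do not mutate the caller's dict
            replace = value.pop('_delete_', False)
            if not replace and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(value, merged[key])
                continue
        merged[key] = copy.deepcopy(value)
    return merged


base = dict(optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001))

lr_only = merge_cfg(dict(optimizer=dict(lr=0.001)), base)
to_adamw = merge_cfg(
    dict(optimizer=dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)),
    base)
print(lr_only['optimizer'])   # {'type': 'SGD', 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0001}
print(to_adamw['optimizer'])  # {'type': 'AdamW', 'lr': 0.0001, 'weight_decay': 0.0001}
```

  Without _delete_=True, the AdamW dict would be merged key by key and would silently keep momentum=0.9 from SGD. That is why the flag is needed when a module is replaced rather than tuned.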

5.3.3 Analysis of configuration files

  Parsing the configuration file is really the job of train.py and test.py. Discussing it here together with building configuration files makes the logic easier to follow.

  Configuration files are generally managed through the Config class. Config.fromfile(filename) reads a configuration file and returns a Config instance cfg; you can also pass a dict to the Config constructor directly. You can then print the configuration with print(cfg.pretty_text) or save it with cfg.dump(filepath).

from mmcv import Config

cfg = Config.fromfile('../configs/test_config.py')

  The source code of fromfile() is shown below; its core is _file2dict(). _file2dict() turns the configuration file into key = value pairs (a .py file is executed as a Python module and its top-level variables are collected) and produces a dictionary named cfg_dict. If the file has a _base_ field, _file2dict() is called recursively for each path in _base_, and the configuration parameters from those files are merged into cfg_dict. This is how configuration inheritance is implemented. Note that _file2dict() checks the keys across the different _base_ files: the same key must not appear in two different base configuration files. Otherwise, Config could not tell which file should take precedence.

@staticmethod  # defined inside class Config
def fromfile(filename,
             use_predefined_variables=True,
             import_custom_modules=True):
    cfg_dict, cfg_text = Config._file2dict(filename,
                                           use_predefined_variables)
    # import_modules_from_strings() imports the modules named in a list of strings
    if import_custom_modules and cfg_dict.get('custom_imports', None):
        import_modules_from_strings(**cfg_dict['custom_imports'])
    return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

  The cfg_dict produced by _file2dict() holds all the text information of the configuration file, converted into variables and stored in a dictionary.
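To make the `_base_` recursion and the duplicate-key check concrete, here is a hypothetical sketch that works on in-memory dicts instead of files (`CONFIG_FILES` and `file2dict_sketch` are invented for illustration). It only mirrors the behavior described above: base files are loaded recursively, a key defined in two base files raises a `KeyError`, and keys in the current file override the merged bases. For brevity the override is a plain top-level replacement rather than MMCV's recursive merge.

```python
# Hypothetical "files": each maps a file name to the variables it defines
CONFIG_FILES = {
    'models/retinanet.py': dict(model=dict(type='RetinaNet', depth=50)),
    'schedules/1x.py': dict(optimizer=dict(type='SGD', lr=0.02), total_epochs=12),
    'my_config.py': dict(_base_=['models/retinanet.py', 'schedules/1x.py'],
                         total_epochs=24),
}


def file2dict_sketch(filename, files=CONFIG_FILES):
    """Hypothetical sketch of `_base_` inheritance on in-memory configs."""
    cfg_dict = dict(files[filename])  # copy so popping `_base_` leaves `files` intact
    base_names = cfg_dict.pop('_base_', [])
    if isinstance(base_names, str):
        base_names = [base_names]

    base_cfg_dict = {}
    for base_name in base_names:
        parent = file2dict_sketch(base_name, files)  # recurse into each base file
        duplicated = base_cfg_dict.keys() & parent.keys()
        if duplicated:
            raise KeyError(f'Duplicate key(s) {sorted(duplicated)} among base configs')
        base_cfg_dict.update(parent)

    # Keys defined in the current file take precedence over the bases
    base_cfg_dict.update(cfg_dict)
    return base_cfg_dict


print(file2dict_sketch('my_config.py'))
```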

  Two more points are worth adding. First, when the Config object is constructed, Python's dict type is converted into ConfigDict. ConfigDict is a subclass of Dict from the third-party library addict, and Dict is itself a subclass of Python's dict. Python's native dict does not support dict.attribute access; with several levels of nested dicts, writing out every key lookup makes the code clumsy. The Dict class implements .attribute access by overriding __getattr__(), so ConfigDict, which inherits from Dict, also lets you read every value in the dictionary with .attribute.

from mmcv import ConfigDict

model = ConfigDict(dict(backbone=dict(type='ResNet', depth=50)))

print(model.backbone.type)  # prints: ResNet
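The attribute access shown above comes from overriding `__getattr__()`. The class below is a minimal, hypothetical stand-in for addict's `Dict` (`AttrDict` is not an MMCV or addict name): it wraps nested dicts on first access so that chained access like `model.backbone.type` works, and it raises `AttributeError` for missing keys, similar to ConfigDict.

```python
class AttrDict(dict):
    """Hypothetical minimal dict subclass supporting `d.key` access."""

    def __getattr__(self, name):
        # __getattr__ is only called when normal attribute lookup fails,
        # so real dict methods such as .items() keep working.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name) from None
        # Wrap nested plain dicts (and store them back) so access can be chained
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
            self[name] = value
        return value

    def __setattr__(self, name, value):
        # `d.key = value` writes into the dictionary itself
        self[name] = value


model = AttrDict(backbone=dict(type='ResNet', depth=50))
print(model.backbone.type)
model.backbone.depth = 101
print(model['backbone']['depth'])
```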

  Second, to cope with dots in configuration file names, _file2dict() works on a copy of the configuration file inside a temporary folder, which on Windows is usually on the C drive. If that location has restricted access rights, an error may occur; this problem only appears on Windows.
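A rough sketch of why the temporary folder helps: a `.py` config is loaded by importing it as a module, and a file name such as `my_cfg_v1.5.py` cannot be imported directly because the dot would be read as a package separator. Copying the source into a temporary directory under a safe module name sidesteps this. The helper `load_py_config` and the module name `tmp_cfg_module` are hypothetical; MMCV's real `_file2dict()` does more (syntax validation, predefined variables, `_base_` handling).

```python
import importlib
import os
import sys
import tempfile


def load_py_config(config_text):
    """Hypothetical sketch: import config source under a dot-free module name."""
    module_name = 'tmp_cfg_module'  # safe name: no dots, valid identifier
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, module_name + '.py')
        with open(path, 'w') as f:
            f.write(config_text)
        sys.path.insert(0, temp_dir)
        importlib.invalidate_caches()  # make sure the new file is visible to the importer
        try:
            mod = importlib.import_module(module_name)
        finally:
            sys.path.pop(0)
        # Keep only user-defined variables, dropping __name__, __builtins__, etc.
        cfg = {k: v for k, v in vars(mod).items() if not k.startswith('__')}
        del sys.modules[module_name]  # allow the next call to import fresh content
    return cfg


print(load_py_config("lr = 0.02\noptimizer = dict(type='SGD', lr=lr)\n"))
```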

5.3.4 Configuration file summary

  To briefly review: a configuration file is a text file containing multiple dict variables. Each dict corresponds to a specific module (which must have been registered) and must contain a type field; its other fields are the constructor arguments of that module. When build() is called to instantiate the module, the class registered under the type string is looked up in the lookup table, and the values of the other fields in the dict are passed as constructor arguments to initialize the module.
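The lookup-table mechanism summarized above can be sketched in a few lines. This `Registry` class is a simplified, hypothetical stand-in for MMCV's (it keeps the `register_module()` decorator and the `type`-field convention but omits scopes, parent registries and `default_args`); `ToyResNet` is an invented module.

```python
class Registry:
    """Hypothetical minimal registry: maps type names to classes."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}  # the lookup table

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        args = dict(cfg)             # copy so the caller's config is untouched
        obj_type = args.pop('type')  # every config dict must have a `type` field
        if obj_type not in self._module_dict:
            raise KeyError(f'{obj_type} is not registered in {self.name}')
        obj_cls = self._module_dict[obj_type]
        return obj_cls(**args)       # remaining fields become constructor arguments


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class ToyResNet:
    def __init__(self, depth, num_stages=4):
        self.depth = depth
        self.num_stages = num_stages


backbone = BACKBONES.build(dict(type='ToyResNet', depth=50))
print(type(backbone).__name__, backbone.depth, backbone.num_stages)
```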

5.4 Training and Testing

  Implementing an algorithm with MMDetection takes four steps. Steps one and two register the dataset and the model, building the basic modules (data flow and model). Step three builds the configuration file, which specifies the required modules and their input parameters. Step four takes the predefined modules one by one according to the configuration file, passes in the specified parameters, and chains them together in the order of the algorithm's pipeline.

5.4.1 train.py file

  Let's walk through the official train.py code (keeping only the core logic), and then explain how MMDetection uses Runner and Hook to schedule the whole training process; this order makes it quicker to understand.

  The main() function of train.py does four things: it parses the configuration file built in step three with the Config class, initializes the model, initializes the datasets, and finally passes the model and datasets to train_detector() to start the training process.

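import copy

from mmcv import Config
from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector


def main():
    args = parse_args()  # command-line parsing, defined elsewhere in train.py

    # Step1: parse the configuration file; args.config is its path (see Section 5.3.3 for how parsing works)
    cfg = Config.fromfile(args.config)

    # Step2: initialize the model; internally this calls DETECTORS.build(cfg)
    model = build_detector(cfg.model)
    # initialize the model weights
    model.init_weights()

    # Step3: initialize the training and validation sets; internally this calls
    # build_from_cfg(cfg, DATASETS), equivalent to DATASETS.build(cfg)
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # during training, the validation set uses the train pipeline instead of the test pipeline
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))

    # Step4: pass in the model and datasets and start training
    train_detector(model, datasets, cfg)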

  train_detector() builds the dataloaders, initializes the optimizer, the runner and the hooks, and finally calls runner.run() to start the actual iterative training. This involves the concept of a Runner, which we won't expand on here; for now it is enough to know that the Runner is also a module, responsible for the model's iterative training.

from mmcv.parallel import MMDataParallel
from mmcv.runner import HOOKS, build_optimizer, build_runner
from mmcv.utils import build_from_cfg
from mmdet.datasets import build_dataloader


def train_detector(model, dataset, cfg):
    # Runner type: EpochBasedRunner or IterBasedRunner
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner['type']

    # Step1: build the dataloaders; the dataset list holds both the training and
    # validation sets, hence the loop. build_dataloader() creates each one with DataLoader
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,  # batch size per GPU
            cfg.data.workers_per_gpu,  # data-loading worker processes per GPU
            len(cfg.gpu_ids),
            dist=False,                # non-distributed (MMDataParallel) training
            runner_type=runner_type) for ds in dataset
    ]

    # Step2: wrap the model with MMDataParallel (single-machine data parallelism;
    # distributed training uses MMDistributedDataParallel instead)
    model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # Step3: initialize the optimizer
    optimizer = build_optimizer(model, cfg.optimizer)

    # Step4: initialize the Runner
    runner = build_runner(
        cfg.runner,
        default_args=dict(model=model, optimizer=optimizer))

    # Step5: register the default hooks (added to the runner._hooks list)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # Step6: register custom hooks (also added to the runner._hooks list)
    if cfg.get('custom_hooks', None):
        for hook_cfg in cfg.custom_hooks:
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # Step7: start the training process
    if cfg.resume_from:
        # resume from a checkpoint
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        # load pretrained weights
        runner.load_checkpoint(cfg.load_from)
    # call run() to start the iterative process
    runner.run(data_loaders, cfg.workflow)

  Although the official train.py is long, its core code consists of operations we already know from PyTorch. The overall train.py process can be summarized as follows:

  1. First, parse the configuration file passed in and instantiate the modules it specifies (model and datasets);
  2. Then build the data_loaders from the datasets, and wrap the model with MMDataParallel for single-machine (multi-GPU) data parallelism;
  3. Then initialize a Runner object, runner, with the model and optimizer;
  4. Register the hooks needed during training;
  5. Call runner.run() to train iteratively according to the workflow specified in the configuration file.

  The rest of this section looks at what happens inside runner.run().

5.4.2 Runner class

  There are two kinds of Runner: EpochBasedRunner and IterBasedRunner. As the names suggest, the former manages the process by epoch and the latter by iteration. Both are subclasses of BaseRunner; neither overrides the constructor, so both inherit BaseRunner's constructor directly:

from abc import ABCMeta


class BaseRunner(metaclass=ABCMeta):
    def __init__(self,
                 model,                 # [torch.nn.Module] the model to run
                 batch_processor=None,  # rarely used
                 optimizer=None,        # [torch.optim.Optimizer] one optimizer, or a dict of optimizers
                 work_dir=None,         # [str] directory for checkpoints and logs
                 logger=None,           # [logging.Logger] logger used during training
                 meta=None,             # [dict] extra information, recorded by the logger hook
                 max_iters=None,        # [int] total number of training iterations
                 max_epochs=None):      # [int] total number of training epochs
        ...

  Every subclass of BaseRunner must implement four methods, run(), train(), val() and save_checkpoint(), which are the core methods of a Runner. Below, EpochBasedRunner is used as an example to analyze these four functions in detail.
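The "must implement" rule is enforced with Python's `abc` module, since BaseRunner uses `metaclass=ABCMeta`. Below is a hypothetical, stripped-down illustration (not MMCV's code; `ToyBaseRunner` and `IncompleteRunner` are invented): a subclass that leaves any of the four abstract methods unimplemented cannot be instantiated.

```python
from abc import ABCMeta, abstractmethod


class ToyBaseRunner(metaclass=ABCMeta):
    """Hypothetical stripped-down base runner with the four abstract methods."""

    def __init__(self, model, max_epochs=None):
        self.model = model
        self._max_epochs = max_epochs
        self._epoch = 0

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def train(self, data_loader, **kwargs):
        pass

    @abstractmethod
    def val(self, data_loader, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self, out_dir, **kwargs):
        pass


class IncompleteRunner(ToyBaseRunner):
    # Implements only train(); run(), val() and save_checkpoint() are missing
    def train(self, data_loader, **kwargs):
        pass


try:
    IncompleteRunner(model=None)
except TypeError as err:
    print('Cannot instantiate:', err)
```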

run() function
  run() is the entry point of the Runner class. It processes the data in data_loaders according to the workflow list. MMCV currently supports two modes, training and validation. For EpochBasedRunner, a workflow of [('train', 2), ('val', 1)] means train for 2 epochs and then validate for 1 epoch, while [('train', 1)] means training only, with no validation. For IterBasedRunner, [('train', 2), ('val', 1)] means train for 2 iterations and then validate for 1 iteration. run() uses getattr(self, mode) to call self.train() or self.val() depending on the mode.
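To see how the workflow list is interpreted, here is a toy, hypothetical EpochBasedRunner-like class (no model; data loaders are plain lists) that reproduces only the control flow of `run()`: `getattr(self, mode)` picks `train` or `val`, only `train` advances the epoch counter, and training stops once `max_epochs` is reached. It records which phase ran at which epoch counter value.

```python
class ToyEpochRunner:
    """Hypothetical toy runner reproducing only the control flow of run()."""

    def __init__(self, max_epochs):
        self._max_epochs = max_epochs
        self._epoch = 0
        self.history = []  # (mode, epoch) pairs, in execution order

    @property
    def epoch(self):
        return self._epoch

    def train(self, data_loader):
        for _ in data_loader:
            pass  # a real runner would call run_iter() on each batch
        self.history.append(('train', self._epoch))
        self._epoch += 1  # only training advances the epoch counter

    def val(self, data_loader):
        for _ in data_loader:
            pass
        self.history.append(('val', self._epoch))

    def run(self, data_loaders, workflow):
        while self.epoch < self._max_epochs:
            for i, (mode, epochs) in enumerate(workflow):
                epoch_runner = getattr(self, mode)  # self.train or self.val
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i])


runner = ToyEpochRunner(max_epochs=3)
runner.run([[1, 2], [3]], workflow=[('train', 2), ('val', 1)])
print(runner.history)
```

With `max_epochs=3`, the second pass through the workflow trains only one more epoch, and validation still runs once after the final training epoch.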

def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    self.call_hook('before_run')
    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow

            # mode == 'train' -> self.train(); mode == 'val' -> self.val()
            epoch_runner = getattr(self, mode)

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                # run train() or val() for one epoch
                epoch_runner(data_loaders[i], **kwargs)
    self.call_hook('after_run')

The train() and val() functions
  train() and val() loop over the data loader, calling run_iter() on each batch to complete one epoch. The self.model.train() and self.model.eval() calls at the start are member functions of torch.nn.Module that switch the module into training or evaluation mode; layers such as BatchNorm and Dropout behave differently in the two modes. Because validation needs no gradient backpropagation, val() is decorated with @torch.no_grad().

import torch


def train(self, data_loader, **kwargs):
    # put the module into training mode
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

@torch.no_grad()
def val(self, data_loader, **kwargs):
    # put the module into evaluation mode
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')
    self.call_hook('after_val_epoch')

def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

    self.outputs = outputs

  The core job of train() and val() is to call run_iter(), which in turn calls model.train_step() or model.val_step() depending on the train_mode parameter. Both eventually reach the forward() function of our own model and return its forward result (usually the loss values). Between the Runner and our own model, the call passes through MMDataParallel, BaseDetector and SingleStageDetector (or TwoStageDetector) before our model's forward() finally runs.
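A hypothetical, framework-free sketch of that chain, loosely modelled on MMDetection 2.x's BaseDetector: `train_step()` calls the model (which dispatches to `forward_train()`), sums every entry whose key contains `'loss'`, and packs the result into the `outputs` dict that the Runner stores. Plain floats stand in for tensors, and `ToyBaseDetector`, `ToyDetector` and their loss values are invented.

```python
class ToyBaseDetector:
    """Hypothetical stand-in for BaseDetector, using floats instead of tensors."""

    def __call__(self, **data):
        # nn.Module.__call__ would dispatch to forward(); mimic that here
        return self.forward(**data)

    def forward(self, img, img_metas, return_loss=True):
        if return_loss:
            return self.forward_train(img, img_metas)
        return self.forward_test(img, img_metas)

    @staticmethod
    def _parse_losses(losses):
        log_vars = {name: float(value) for name, value in losses.items()}
        # Only entries whose key contains 'loss' contribute to the total loss
        loss = sum(value for name, value in log_vars.items() if 'loss' in name)
        log_vars['loss'] = loss
        return loss, log_vars

    def train_step(self, data, optimizer=None):
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        return dict(loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))


class ToyDetector(ToyBaseDetector):
    # Our "own model": only forward_train() and forward_test() are written here
    def forward_train(self, img, img_metas):
        return dict(loss_cls=0.75, loss_bbox=0.25, acc=0.9)

    def forward_test(self, img, img_metas):
        return [[] for _ in img_metas]


outputs = ToyDetector().train_step(dict(img=None, img_metas=[{}, {}]))
print(outputs['loss'], outputs['num_samples'])
```

The `'acc'` entry is logged but not added to the total loss, which is why `outputs['loss']` is the sum of the two `loss_*` terms only.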

  Careful readers may ask: where is the gradient backpropagation and optimization step? In MMDetection it is performed by OptimizerHook, a Hook with priority ABOVE_NORMAL that implements after_train_iter():

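# Simplified excerpt of mmcv's OptimizerHook (already registered inside mmcv; shown for reading)
from mmcv.runner import HOOKS, Hook
from torch.nn.utils import clip_grad


@HOOKS.register_module()
class OptimizerHook(Hook):
    def __init__(self, grad_clip=None):
        self.grad_clip = grad_clip  # e.g. dict(max_norm=35, norm_type=2), or None

    def clip_grads(self, params):
        params = list(filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # add the gradient norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()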
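How does the Runner decide the order in which hooks run their after_train_iter() methods? A simplified, hypothetical sketch of the hook mechanism: hooks are kept in a list sorted by priority (a smaller number means a higher priority; the numbers below are assumed to match MMCV's `Priority` enum, e.g. ABOVE_NORMAL = 40 and NORMAL = 50), and `call_hook(name)` calls the method of that name on every hook in order. `ToyRunner` and `RecordingHook` are invented for illustration.

```python
# Assumed priority values (smaller number = runs earlier)
PRIORITY = {'VERY_HIGH': 10, 'HIGH': 30, 'ABOVE_NORMAL': 40, 'NORMAL': 50, 'LOW': 70}


class Hook:
    """Base hook: each stage method does nothing unless a subclass overrides it."""

    def after_train_iter(self, runner):
        pass


class RecordingHook(Hook):
    def __init__(self, name):
        self.name = name

    def after_train_iter(self, runner):
        runner.calls.append(self.name)


class ToyRunner:
    def __init__(self):
        self._hooks = []
        self.calls = []

    def register_hook(self, hook, priority='NORMAL'):
        hook.priority = PRIORITY[priority]
        # Insert after the last hook whose priority number is <= this one
        for i in range(len(self._hooks) - 1, -1, -1):
            if hook.priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                return
        self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)


runner = ToyRunner()
runner.register_hook(RecordingHook('checkpoint'), priority='NORMAL')
runner.register_hook(RecordingHook('logger'), priority='LOW')
runner.register_hook(RecordingHook('optimizer'), priority='ABOVE_NORMAL')
runner.call_hook('after_train_iter')  # triggered once per iteration in EpochBasedRunner.train()
print(runner.calls)
```

Even though it is registered last, the ABOVE_NORMAL hook runs first, so the parameters are updated before lower-priority hooks such as checkpointing or logging see the iteration.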

save_checkpoint() function
  The save_checkpoint() function is fairly simple, so I won't go into detail. In the end it calls torch.save to write the checkpoint to a file with the following structure:

checkpoint = {
    'meta': dict(),        # environment information (e.g. epoch_num, iter_num)
    'state_dict': dict(),  # the model's state_dict()
    'optimizer': dict(),   # the optimizer's state_dict()
}
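A stdlib-only sketch of that file layout, with `pickle` standing in for `torch.save` (which also serializes with pickle by default) and plain dicts standing in for the real `state_dict()`s; the helper names `save_checkpoint_sketch`/`load_checkpoint_sketch` and all values are hypothetical.

```python
import os
import pickle
import tempfile


def save_checkpoint_sketch(filename, model_state, optimizer_state, meta):
    """Hypothetical sketch: write the three-part checkpoint dict to disk."""
    checkpoint = {
        'meta': meta,                  # environment info, e.g. epoch and iter numbers
        'state_dict': model_state,     # stand-in for model.state_dict()
        'optimizer': optimizer_state,  # stand-in for optimizer.state_dict()
    }
    with open(filename, 'wb') as f:
        pickle.dump(checkpoint, f)


def load_checkpoint_sketch(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as work_dir:
    path = os.path.join(work_dir, 'epoch_1.pth')
    save_checkpoint_sketch(path, {'fc.weight': [0.1, 0.2]},
                           {'state': {}, 'param_groups': []},
                           meta=dict(epoch=1, iter=500))
    ckpt = load_checkpoint_sketch(path)
    print(sorted(ckpt), ckpt['meta'])
```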


Origin blog.csdn.net/qq_16137569/article/details/121316235