MMDetection のチュートリアルはインターネット上にたくさんあるようですが、体系的ではなく、読んだ後でも MMDetection の使い方がわかりません。ここでも、公式チュートリアルに直接従って、ソース コードと組み合わせて MMDetection を学習することをお勧めします。関連リンクは次のようにまとめられています。
この記事では、MMDetection で独自のアルゴリズムを最初から構築する方法を紹介します。最初の数回のブログは学習過程のメモです MMDetection の原理をソースコードそのものから解析して比較的詳細にまとめました このブログでは MMDetection の利用方法と処理原理をマクロな視点で再整理してみます前月の学習プロセスとしてまとめます。
1. フレームワークの概要
MMDetection は、SenseTime と香港中文大学がターゲット検出タスクのために立ち上げたオープンソース プロジェクトで、Pytorch に基づいた多数のターゲット検出アルゴリズムを実装し、データセット構築、モデル構築、トレーニング戦略のプロセスをカプセル化しています。モジュール呼び出しの方法では、少ないコード量で新しいアルゴリズムを実装でき、コードの再利用率が大幅に向上します。
MMDetection に加えて、MMLab ファミリ全体には、ターゲット追跡タスク用の MMTracking や 3D ターゲット検出タスク用の MMDetection3D などのオープンソース プロジェクトも含まれており、これらはすべて Pytorch と MMCV に基づいています。Pytorch についてはあまり説明の必要はありません。MMCV はコンピュータ ビジョン用の基本ライブラリです。主な機能は、Pytorch に基づいた一般的なトレーニング フレームワークを提供することです。たとえば、よく言及するレジストリ、ランナー、フック、その他の機能はすべて でサポートされています。 MMCV 。さらに、MMCV は、汎用 IO インターフェイス、複数の CNN ネットワーク構造、および高品質の実装を備えた一般的な CUDA オペレーターも提供しますが、ここではこれ以上拡張しません。
2. フレームワークの全体的なプロセス
2.1 パイトーチ
Pytorch を使用して新しいアルゴリズムを構築する場合、通常は次の手順が含まれます。
- データセットの構築: 新しいクラスを作成し、
Dataset
クラスを継承し、__getitem__()
データとタグのロードとトラバーサル機能を実現するメソッドを書き換え、データの前処理プロセスをパイプライン形式で定義します - データ ローダーを構築する: 対応するパラメーターを渡して DataLoader をインスタンス化します。
- モデルの構築: 新規クラスの作成、
Module
クラスの継承、forward()
関数定義モデルの転送処理の書き換え - 損失関数とオプティマイザを定義します。アルゴリズムに従って適切な損失関数とオプティマイザを選択します。
- トレーニングと検証: DataLoader からデータとラベルを周期的に取得し、ネットワーク モデルに送信し、損失を計算し、オプティマイザーを使用してバックプロパゲーションの勾配に従って反復最適化を実行します。
- その他の操作: トレーニング トリック、ログの出力、チェックポイントの保存などの操作をメイン呼び出し関数に任意に挿入できます。
2.2MM検出
Pytorch を使用して新しいアルゴリズムを構築する場合、通常は次の手順が含まれます。
- 登録データ セット:
CustomDataset
これは、元のベースで MMDetection を再カプセル化したものでDataset
、そのメソッドはトレーニング モードとテスト モードに従って__getitem__()
リダイレクトされprepare_train_img()
、機能します。prepare_test_img()
ユーザーがクラスを継承して独自のデータセットを構築する場合、データとラベルの読み込みおよび走査メソッドを定義する関数を書き直すCustomDataset
必要があります。データセットクラスの定義が完了したら、モジュールの登録も行う必要があります。load_annotations()
get_ann_info()
DATASETS.register_module()
- モデルの登録: モデルの構築方法は Pytorch の方法と似ており、新しい
Module
サブクラスを作成してforward()
関数を書き換えます。唯一の違いは、はい、サブクラスBaseModule
ではなくMMDetection を継承する必要があることです。MMLab のモデルはすべてこのクラスから継承する必要があります。また、MMDetectionでは完全なモデルをバックボーン、ネック、ヘッドに分割して管理するため、ユーザーはこのようにアルゴリズムモデルを3つのクラスに分解し、それぞれ使用、モジュール登録を完了する必要があります。Module
BaseModule
Module
BACKBONES.register_module()
NECKS.register_module()
HEADS.register_module()
- 構成ファイルを構築する: 構成ファイルは、アルゴリズムの各コンポーネントの動作パラメーターを構成するために使用され、通常、データセット、モデル、スケジュール、ランタイムの 4 つの部分を含めることができます。対応するモジュールの定義と登録が完了したら、構成ファイルに対応する動作パラメータを設定します。その後、MMDetection は
Registry
クラスを通じて構成ファイルを読み取って解析し、モジュールのインスタンス化を完了します。さらに、構成ファイルは_base_
フィールドを通じて継承関数を実装し、コードの再利用を向上させることができます。 - トレーニングと検証: 各モジュールのコード実装、モジュール登録、構成ファイルの書き込みが完了すると、ユーザーが追加のコードを記述することなく、モデルを使用してトレーニングおよび検証できます
./tools/train.py
。./tools/test.py
2.3 プロセスの比較
MMDetection は、ステップの点で Pytorch のアルゴリズム実装ステップとはかなり異なりますが、基礎となるロジック実装は本質的に Pytorch と同じです。比較のために次の図を参照できます。青色の部分は Pytorch のプロセスを表し、オレンジ色の部分は Pytorch のプロセスを示しますMMDetection プロセス。緑色の部分はアルゴリズム フレームワークとは関係のない一般的なプロセスを示します。
MMDetection のアルゴリズム実装プロセスを開始する前に、まず登録メカニズムとフック メカニズムについて一般的に理解しておく必要があります。すぐに読んで、登録メカニズムとフック メカニズムについて一般的に理解することをお勧めします。第 5 章を読む 登録メカニズムとフックメカニズムの詳細を振り返ると、より深く理解できます。
3. 登録メカニズム
3.1 レジストリクラス
MMDetection は、MMCV の下流プロジェクトとして、MMCV のモジュール管理メソッドである登録メカニズムを継承しています。簡単に言うと、登録メカニズムは複数のルックアップ テーブルを維持することであり、キーはモジュールの名前、値はモジュールのハンドルです。各ルックアップ テーブルは、同様の機能を持つ異なるモジュールのバッチを管理します。新しいモジュールを作成するたびに、key-value
モジュールが実現する機能に応じて、対応するクエリ ペアを対応するクエリ テーブルに保存する必要があります。この保存処理を「登録」と呼びます。モジュールを呼び出したいときは、モジュール名に従ってルックアップテーブルから対応するモジュールハンドルを見つけるだけで、モジュールの初期化やメソッド呼び出しなどの操作を完了できます。MMCV は、Registry
クラスを介して文字列 (キー) からクラス (値) へのマッピングを実装します。
レジストリのコンストラクタは以下の通りです 変数はself._module_dict
前述の「ルックアップテーブル」です 登録されたモジュールはこのディクショナリ型の変数に格納されます レジストリインスタンスを作成するということは、新しいルックアップテーブルを作成することになります さらに、レジストリは継承メカニズムもサポートしています。
from mmcv.utils import Registry
class Registry:
# 构造函数
def __init__(self, name, build_func=None, parent=None, scope=None):
# 注册器的名称
self._name = name
# 使用module_dict管理字符串到类的映射
self._module_dict = dict()
# 使用children管理注册器的子类
self._children = dict()
# build_func按照如下优先级初始化:
# 1. build_func: 优先使用指定的函数
# 2. parent.build_func: 其次使用父类的build_func
# 3. build_from_cfg: 默认从config dict中实例化对象
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
# 设置父类-子类的从属关系
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
モジュールの登録はレジストリのメンバー関数を通じてregister_module()
実現され、register_module()
内部的には別のプライベート関数が呼び出されます_register_module()
。モジュール登録のコア機能は実際には_register_module()
レジストリ内で実現されます。中心となるコードも非常に単純で、受信したmodule_name
合計を辞書にmodule_class
保存するだけです。self._module_dict
def _register_module(self, module_class, module_name=None, force=False):
# 如果未指定模块名称则使用默认名称
if module_name is None:
module_name = module_class.__name__
# 为了支持在nn.Sequentail中构建pytorch模块, module_name为list形式
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
# 如果force=False, 则不允许注册相同名称的模块
# 如果force=True, 则用后一次的注册覆盖前一次
if not force and name in self._module_dict:
raise KeyError(f'{
name} is already registered in {
self.name}')
# 将当前注册的模块加入到查询表中
self._module_dict[name] = module_class
文字列を通じてモジュールのハンドルを取得した後、self.build_func
関数ハンドルを通じてモジュールをインスタンス化できます。build_func
手動で指定することも、親クラスから継承することもでき、一般的に、このbuild_from_cfg()
関数はデフォルトで使用されます。つまり、cfg
モジュールは構成パラメーターで初期化されます。構成パラメーターはcfg
辞書であり、その中のフィールドはtype
モジュール名の文字列であり、他のフィールドはモジュール コンストラクターの入力パラメーターに対応します。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
# 将cfg以外的外部传入参数也合并到args中
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
# 获取模块名称
obj_type = args.pop('type')
if isinstance(obj_type, str):
# get函数返回registry._module_dict中obj_type对应的模块句柄
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f'{
obj_type} is not in the {
registry.name} registry')
elif inspect.isclass(obj_type):
# type值是模块本身
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {
type(obj_type)}')
# 模块初始化, 返回模块实例
try:
return obj_cls(**args)
except Exception as e:
raise type(e)(f'{
obj_cls.__name__}: {
e}')
registry
パラメータが現在のレジストリ自体を指す必要があること を考慮すると、通常はbuild()
代わりに Registry クラスのメソッドを呼び出しますself.build_func
。
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
以下は、ネットワーク モデルの登録と呼び出しのプロセスをシミュレートする小さな例です。Registry オブジェクトを出力するとき、実際にはself._module_dict
そのオブジェクト内の値が出力されることに注意してください。
# 实例化一个注册器用来管理模型
MODELS = Registry('myModels')
# 方式1: 在类的创建过程中, 使用函数装饰器进行注册(推荐)
@MODELS.register_module()
class ResNet(object):
def __init__(self, depth):
self.depth = depth
print('Initialize ResNet{}'.format(depth))
# 方式2: 完成类的创建后, 再显式调用register_module进行注册(不推荐)
class FPN(object):
def __init__(self, in_channel):
self.in_channel= in_channel
print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)
print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""
# 配置参数, 一般cfg从配置文件中获取
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# 实例化模型(将配置参数传给模型的构造函数), 得到实例化对象
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""
3.2 登録メカニズムの概要
登録メカニズムはモジュール管理の手段です。モジュールは、異なるモジュール機能に従ってグループ化され、管理されます。各グループはクエリ テーブルによって維持されます。クエリ テーブルには、モジュール名 (文字列) とモジュール自体 (それ自体) の間のマッピング関係が記録されます。 ) 、マッピング関係をクエリ テーブルに記録するプロセスは「登録」と呼ばれます。モジュールが登録されると、モジュール名に従って特定のモジュール ハンドルに簡単にインデックスを付けることができ、通常のプログラム フローに従ってモジュールを初期化して使用できます。モジュールの登録と使用は、次の 5 つのステップで構成されます。
- 新しいクラスを作成してカスタム関数を実装する
- このクラスを対応するクエリテーブルに登録します(
register_module
) - 構成ファイルでモジュールの初期化パラメータを指定します
- ビルド関数を使用してモジュールをインスタンス化します (
build_from_cfg
) - このインスタンス オブジェクトを使用して関数関数を実行します
4. フック機構
4.1 フッククラス
MMDetection のアルゴリズム プロセス全体はブラック ボックスのようなものです。入力 (構成ファイル) が与えられると、ブラック ボックスはアルゴリズムの結果を吐き出します。プロセス全体は高度にカプセル化されており、コードを手動で記述する必要はほとんどありません。しかし、アルゴリズム実行プロセスにカスタム操作を追加するにはどうすればよいでしょうか? これがフック機構の役割です。
簡単に言えば、フックは、プログラム内の事前定義された位置で事前定義された関数を実行できるトリガーとして理解できます。MMCV では、アルゴリズムのライフサイクルに従ってユーザー定義関数を挿入できる 6 つのサイトが事前に定義されており、次の図に示すように、ユーザーは各サイトに任意の数の関数操作を自由に挿入できます。
これらの 6 つの位置は、基本的にカスタム操作が表示される位置をカバーしています。MMCV には、一般的に使用されるフックがいくつか実装されています。デフォルトのフックでは、ユーザーが自分自身を登録する必要はなく、対応するパラメーターは構成ファイルを通じて構成できます。カスタム フックでは、ユーザーが次のことを行う必要があります。設定ファイルに設定し、手動設定フィールドにcustom_hooks
登録します。
Hook
クラス自体にはコードがほとんどなく、事前定義された場所にインターフェイス関数が提供されるだけなので、カスタム フックはクラスを継承しHook
、必要に応じて対応するインターフェイス関数を書き直す必要があります。たとえば、チェックポイント保存操作は通常、各反復またはエポックの後に発生するため、after_train_iter
と を書き直す必要がありますafter_train_epoch
。
class Hook:
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
@HOOKS.register_module()
class CheckpointHook(Hook):
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
**kwargs):
...
def after_train_iter(self, runner):
...
def after_train_epoch(self, runner):
...
他のモジュールとは異なり、Hook を定義した後 (レジスタに登録した後HOOKS
)、使用する前に Runner に登録する必要があります。最初の登録は、HOOKS
プログラムがフック名に従って対応するモジュールを見つけられるようにするものであり、ランナーへの 2 番目の登録は、プログラムが事前定義された位置まで実行されるときに対応する関数を呼び出すことです。
Runner は MMCV がトレーニング プロセスを管理するために使用するクラスです。内部でリスト型変数を保持します。self._hooks
トレーニング プロセス中に呼び出されるすべての Hookインスタンス オブジェクトを優先順位に従って追加する必要がありますself._hooks
。このプロセスはRunner.register_hook()
関数によって実現されます。MMCV では、いくつかの優先度レベルが事前に定義されています。数値が小さいほど、優先度が高くなります。デフォルトの評価方法が粒度が高すぎると感じる場合は、0 ~ 100 の整数を直接渡して細かく分割することもできます。
def register_hook(self, hook, priority='NORMAL'):
"""预定义优先级
+--------------+------------+
| Level | Value |
+==============+============+
| HIGHEST | 0 |
+--------------+------------+
| VERY_HIGH | 10 |
+--------------+------------+
| HIGH | 30 |
+--------------+------------+
| ABOVE_NORMAL | 40 |
+--------------+------------+
| NORMAL | 50 |
+--------------+------------+
| BELOW_NORMAL | 60 |
+--------------+------------+
| LOW | 70 |
+--------------+------------+
| VERY_LOW | 90 |
+--------------+------------+
| LOWEST | 100 |
+--------------+------------+
"""
hook.priority = priority
# 插入法排序将Hooks按照priority大小升序排列
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
フック インスタンスがそれに追加されるとself._hooks
、事前定義された場所でフック インスタンスを呼び出してcall_hook()
、各フック インスタンスの対応するメソッドを呼び出すことができます。call_hook()
コールバック関数と呼ばれます。
# 开始运行时调用
self.call_hook('after_train_epoch')
while self.epoch < self._max_epochs:
# 开始 epoch 迭代前调用
self.call_hook('before_train_epoch')
for i, data_batch in enumerate(self.data_loader):
# 开始 iter 迭代前调用
self.call_hook('before_train_iter')
self.model.train_step()
# 经过一次迭代后调用
self.call_hook('after_train_iter')
# 经过一个 epoch 迭代后调用
self.call_hook('after_train_epoch')
# 运行完成前调用
self.call_hook('after_train_epoch')
呼び出すと、 Hook インスタンス内のすべての Hook インスタンスcall_hook()
を走査しself._hooks
、fn_name
Hook インスタンスの指定されたメンバー関数を呼び出します。たとえばfn_name='before_train_epoch'
、call_hook()
すべてのフックbefore_train_epoch()
関数が 1 つずつ呼び出されます。そしてself._hooks
優先度に従ってソートされているため、call_hook()
優先度の高いフックメソッドが最初に呼び出されます。
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
4.2 フック機構の概要
フックはプログラム内の固定位置に設定されたトリガーです。プログラムが事前に設定された位置まで実行されると、ブレークポイントがトリガーされ、フック関数のフローが実行され、その後ブレークポイントの位置に戻ってコードの実行が継続されます。メインプロセス。フックの実装は 5 つのステップで構成されます。
- Hook基本クラスを継承するクラスを定義する
- カスタムフックの関数に従って、フック基本クラス内の対応する関数を選択的に書き換えます。
- カスタムフックモジュールをHOOKSクエリテーブルに登録します(
register_module
) - Hook モジュールをインスタンス化して Runner に登録します (
register_hook
) - コールバック関数を使用して、書き換えられたフック関数を呼び出します (
call_hook
)
5. アルゴリズム実装プロセス
セクション 2.2 では、MMDetection を使用して新しいアルゴリズムを実装するには、登録データ セット、登録モデル、構成ファイルの構築、トレーニング/検証の 4 つのステップが含まれると述べました。MMDetection のアルゴリズム実装プロセスを理解するには、Config、Registry、Runner、Hook の 4 つのクラスを完全に理解する必要があります。
5.1 登録データセット
CustomDataset
独自のデータセットを定義する場合は、継承された新しいDataset クラスを作成し、load_annotations()
関数と関数を書き直す 必要がありますget_ann_info()
。公式ドキュメントには、ユーザーが使用したい場合はCustomDataset
、既存のデータセットを MMDetection 互換形式 (COCO 形式または中間形式) に変換する必要があると記載されています。しかし、基礎となるコードを調べたところ、データ形式が達成したものと一致する限り、そのような制限はありませんでしたload_annotations()
。get_ann_info()
"""
中间数据格式:
[
{
'filename': 'a.jpg', # 图片路径
'width': 1280, # 图片尺寸
'height': 720,
'ann': { # 标注信息
'bboxes': <np.ndarray, float32> (n, 4), # 标注框坐标(x1, y1, x2, y2)
'labels': <np.ndarray, int64> (n, ), # 标注框类别
'bboxes_ignore': <np.ndarray, float32> (k, 4), # 不关注的标注框坐标(可选)
'labels_ignore': <np.ndarray, int64> (k, ) # 不关注的标注框类别(可选)
}
},
...
]
"""
class CustomDataset(Dataset):
CLASSES = None
def __init__(self,
ann_file, # 文件路径
pipeline, # 数据预处理pipeline
classes=None, # 检测类别
data_root=None, # 文件根路径
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False, # 为True的话将不会加载标注信息
filter_empty_gt=True): # 为True的话将会过滤没有标注框的图像(只在test_mode=False的条件下有效)
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)
# 调用load_annotations函数加载样本和标签
self.data_infos = self.load_annotations(self.ann_file)
# 用户可以通过重写_filter_imgs()函数在训练过程中实现自定义的样本过滤功能
if not test_mode:
valid_inds = self._filter_imgs()
self.data_infos = [self.data_infos[i] for i in valid_inds]
# 根据pipeline对样本进行预处理
self.pipeline = Compose(pipeline)
Pytorch のトラバーサルは関数Dataset
を書き換えることで__getitem__()
実現しますが、MMDetection はMMDetection のサブクラスCustomDataset
であるにも関わらず、トレーニングモードやテストモードでのデータ管理を容易にするために関数を書き換える必要がありません。関数は、現在の実行モードまたはに応じて呼び出すことができます。この 2 つの違いは、トレーニング ラベルをロードするかどうかです。したがって、 sum関数を書き直すだけでよく、残りは MMDetection に任せます。Dataset
__getitem__()
__getitem__()
prepare_train_img()
prepare_test_img()
load_annotations()
get_ann_info()
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_img(idx)
else:
return self.prepare_train_img(idx)
# 返回预处理后的训练样本及标签
def prepare_train_img(self, idx):
img_info = self.data_infos[idx]
# 调用get_ann_info获取训练标签
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
return self.pipeline(results)
# 返回预处理后的测试样本
def prepare_test_img(self, idx):
img_info = self.data_infos[idx]
results = dict(img_info=img_info)
return self.pipeline(results)
カスタム Dataset クラスを完了したら、@DATASETS.register_module()
現在のモジュールを DATASETS テーブルに追加することを忘れないでください。
5.2 モデルの登録
ネットワーク モデルの定義は比較的単純で、Pytorch との違いは 2 つだけです。
- 継承された親クラスが次から
Module
変更されましたBaseModule
- モデルは、背骨、首、頭の構造に応じて 3 つの部分に分解する必要があり、それぞれ
BACKBONES
、NECKS
、で定義および登録されますHEADS
。
5.3 ビルド設定ファイル
セクション 2.2 で述べたように、MMDetection フレームワークでは、反復トレーニング/テスト プロセス用に追加のコードを実装する必要はなく、既成の train.py または test.py を実行するだけで済みます。しかし、MMDetection はどのモジュールが必要かをどのようにして知るのでしょうか? これが設定ファイルの役割です。
5.3.1 設定ファイルの構成
構成ファイルは、一連の変数定義で構成されるテキスト ファイルです。dict
変数のタイプは各モジュールを表し、dict
変数type
にはモジュール名を表すフィールドと、モジュール コンストラクターのパラメーターに対応するその他のフィールドが含まれている必要があります。モジュールの初期化 (この記事の第 3 章の関数を参照build_from_cfg()
)。モジュールは登録する必要があります。登録しない場合、後続の MMDetection はtype
値に基づいて対応するモジュールを見つけることができません。タイプの変数に加えて、構成ファイルはdict
他のタイプにすることもできます。これは通常dict
、次のような補助変数によって定義される中間変数です。
test_pipeline = [
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='PadMultiViewImage', size_divisor=32)
]
evaluation = dict(interval=2, pipeline=test_pipeline)
構成ファイルは、変数を通じて実装される継承操作もサポートしています_base_
。継承する設定ファイルのパスを格納する型変数_base_
です。構成ファイルを解析する場合、ファイル パーサーはすべての構成ファイルをlist
再帰的に解析します (他の構成ファイルには変数が含まれる場合もあります)。_base_
バックアップされる構成ファイルは、データセット、モデル、トレーニング戦略 (スケジュール)、およびデフォルトのランタイム構成 (default_runtime) に対応する次の 4 つのファイルを継承します。
_base_ = [
'mmdetection/configs/_base_/models/fast_rcnn_r50_fpn.py', # models
'mmdetection/configs/_base_/datasets/coco_detection.py', # datasets
'mmdetection/configs/_base_/schedules/schedule_1x.py', # schedules
'mmdetection/configs/_base_/default_runtime.py', # defualt_runtime
]
上記の 4 つの基本構成ファイルを継承する構成ファイルを印刷すると、次の内容が表示されます。これは、完全な構成ファイルに含める必要がある構成情報でもあります。もちろん、カスタム構成情報を任意に追加することもできます。したがって、通常、新しい構成ファイルを作成するときは、これら 4 つの基本構成ファイルを継承し、これに基づいて的を絞った調整を行います。
# 1. 模型配置(models) =========================================
model = dict(
type='FastRCNN', # 模型名称是FastRCNN
backbone=dict( # BackBone是ResNet
type='ResNet',
...,
),
neck=dict( # Neck是FPN
type='FPN',
...,
),
roi_head=dict( # Head是StandardRoIHead
type='StandardRoIHead',
...,
loss_cls=dict(...), # 分类损失函数
loss_bbox=dict(...), # 回归损失函数
),
train_cfg=dict( # 训练参数配置
assigner=dict(...), # BBox Assigner
sampler=dict(...), # BBox Sampler
...
),
test_cfg =dict( # 测试参数配置
nms=dict(...), # NMS后处理
...,
)
)
# 2. 数据集配置(datasets) =========================================
dataset_type = '...' # 数据集名称
data_root = '...' # 数据集根目录
img_norm_cfg = dict(...) # 图像归一化参数
train_pipeline = [ # 训练数据处理Pipeline
...,
]
test_pipeline = [...] # 测试数据处理Pipeline
data = dict(
samples_per_gpu=2, # batch_size
workers_per_gpu=2, # GPU数量
train=dict( # 训练集配置
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json', # 标注问加你
img_prefix=data_root + 'train2017/', # 图像前缀
pipline=trian_pipline, # 数据预处理pipeline
),
val=dict( # 验证集配置
...,
pipline=test_pipline,
),
test=dict( # 测试集配置
...,
pipline=test_pipline,
)
)
# 3. 训练策略配置(schedules) =========================================
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
# 4. 运行配置(runtime) =========================================
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
さらに、ユーザー定義モジュールのインポートに使用されるオプションの構成パラメータがいくつかあります。構成ファイル パーサーがこのフィールドを解決すると、フィールドに含まれるモジュールをプログラムにインポートする関数custom_imports
が呼び出されます。import_modules_from_strings()
imports
custom_imports = dict(imports=['os.path', 'numpy'], # list类型, 需要导入的模块名称
allow_failed_imports=False) # 如果设为True, 导入失败时会返回None而不是报错
5.3.2 設定ファイルの変更
構成ファイルを変更する場合には、次の 2 つの状況が考えられます。
- 既存の辞書のパラメータを変更する: 対応するパラメータを直接書き換えます
- 元の辞書のパラメータをすべて削除し、新しいパラメータのセットに置き換える必要があります:
_delete_=True
addfields
学習率の変更とオプティマイザーの置き換えを例として、次の 2 つのケースで構成ファイルを変更する方法を説明します。
# 从_base_中继承的原始优化器
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
# 修改学习率
optimizer = dict(lr=0.001)
# 修改后optimizer变成
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
# 将原来的SGD替换成AdamW
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)
# 替换后optimizer变成
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
5.3.3 設定ファイルの分析
設定ファイルの解析は実際には train.py と test.py が行うべきことですが、ここでは設定ファイルの構築と一緒に説明するので、ロジックはよりスムーズになります。
通常、構成ファイルを管理するには Config クラスを使用します。を使用しConfig.fromfile(filename)
て構成ファイルを読み取り (dict を直接渡すこともできます)、Config クラス インスタンス cfg を返します。その後、print(cfg.pretty_text)
によって構成ファイル情報を出力したり、cfg.dump(filepath)
によって構成ファイル情報を保存したりできます。
from mmcv import Config
cfg = Config.fromfile('../configs/test_config.py')
fromfile()
関数のソースコードは次のとおりで、そのコア関数は です_file2dict()
。_file2dict()
テキストの順序に従って、設定ファイルを key = value の形式に従って解析し、cfg_dict
という名前の辞書を取得します。フィールドがある場合は、含まれるファイル パスごと_base_
に関数が_base_
再度呼び出され_file2dict()
、設定が取得されますファイルに含まれるパラメータが に追加され、設定ファイルの継承cfg_dict
機能が実装されます。異なるファイルに含まれるキー値は内部的に検証され、異なる基本構成ファイルで重複したキー値は許可されないことに注意してください。そうしないと、Config はどの構成ファイルを標準として採用するかを認識できなくなります。_file2dict()
_base_
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
# import_modules_from_strings()是根据字符串列表导入对应的模块
if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
呼び出しと解析_file2dict()
によって得られる形式cfg_dict
は以下の通りで、設定ファイル内のすべてのテキスト情報が変数に変換され、辞書型に格納されます。
他に追加する必要がある点が 2 つあり、1 つは Config オブジェクトを構築するときに、Python のデータ型が処理用の型dict
に変換されることです。これは、サードパーティのライブラリ addictのサブクラス(python のサブクラスでもあります) です。これは、python のネイティブ型がアクセス メソッドをサポートしていないためです。特に、複数の dict レイヤーが内部にネストされている場合、key のアクセス メソッドが使用されている場合、コードの書き込みは非常に非効率であり、クラスは書き換えによってアクセス メソッドを実装します。したがって、継承されたものは、ディクショナリ内の各メンバー値へのアクセスもサポートします。ConfigDict
ConfigDict
Dict
Dict
dict
dict
.属性
dict
Dict
__getattr__()
.属性
Dict
ConfigDict
.属性
from mmcv import ConfigDict
model = ConfigDict(dict(backbone=dict(type='ResNet', depth=50)))
print(model.backbone.type) # 输出 'ResNet'
次に、設定ファイル名の小数点に対応するため、_file2dict()
Cドライブに一時フォルダを作成して動作させますが、Cドライブにアクセス権設定がある場合、エラーが表示される場合がありますが、この問題は発生するだけです。 Windows システム上で。
5.3.4 設定ファイルの概要
簡単に説明すると、構成ファイルはdict
複数の変数を含むテキスト ファイルであり、各dict
変数は特定のモジュールに対応し (モジュールは登録されている必要があります)、dict
必ずフィールドがありtype
、その他のフィールドはモジュールの構築パラメータに対応します。呼び出し関数によってbuild()
モジュールがインスタンス化されると、type
文字列の値に従って対応するモジュール ハンドルがルックアップ テーブルから検索され、dict
その中の他のフィールドの値がモジュールを初期化するための構築パラメータとして使用されます。
5.4 トレーニングとテスト
MMDetection を使用したアルゴリズムの実装は 4 つのステップで構成されます。最初と 2 番目のステップでは、データ セットとモデルを登録して基本モジュール (データ ストリームとモデル) を構築し、構成ファイルを構築する 3 番目のステップでは、必要なモジュールとモジュール入力を指定します。パラメータ、次の 4 番目のステップでは、構成ファイルに従って事前定義モジュールを 1 つずつ抽出し、指定された入力パラメーターを渡し、アルゴリズム プロセスに従ってそれらを順番につなぎ合わせます。
5.4.1 train.py ファイル
公式コードを見てみましょうtrain.py
(コア関数コードのみを保持します)。それから、より早く理解できるように、MMDetection がランナーとフックを使用してトレーニング プロセス全体をスケジュールする方法を紹介します。
train.py
メインの呼び出し関数は 4 つのことを行います。1 つは、Config クラスを使用して 3 番目のステップで構築した構成ファイルを解析し、次にモデルとデータ セットを初期化し、最後にモデルとデータ セットを関数に渡して開始できるようにすることです。train_detector()
トレーニングのプロセス。
def main():
# Step1: 解析配置文件, args.config是配置文件路径(如何解析配置文件可以参考本文4.3.3节)
cfg = Config.fromfile(args.config)
# Step2: 初始化模型, 函数内部调用的是DETECTORS.build(cfg)
model = build_detector(cfg.model)
# 初始化模型权重
model.init_weights()
# Step3: 初始化训练集和验证集, 函数内部调用build_from_cfg(cfg, DATASETS), 等价于DATASETS.build(cfg)
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
val_dataset.pipeline = cfg.data.train.pipeline # 验证集在训练过程中使用train pipeline而不是test pipeline
datasets.append(build_dataset(val_dataset))
# Step4: 传入模型和数据集, 准备开始训练模型
train_detector(model, datasets, cfg)
train_detector()
この関数は主にデータローダーを構築し、オプティマイザー、ランナー、フックを初期化し、最後にrunner.runを呼び出して正式な反復トレーニングプロセスを開始します。これにはランナーの概念が含まれますが、ここでは拡張しません。ランナーもモジュールであり、モデルの反復トレーニングを担当することだけを理解する必要があります。
def train_detector(model, dataset, cfg):
# 获取Runner类型, EpochBasedRunner或IterBasedRuner
runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner['type']
# Step1: 获取dataloader, 因为dataset列表里包含了训练集和验证集, 所以使用for循环的方式构建dataloader
# build_dataloader()会用DataLoader类进行dataloader的初始化
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu, # batch_size
runner_type=runner_type) for ds in dataset
]
# Step2: 封装模型, 为了进行分布式训练
model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
# Step3: 初始化优化器
optimizer = build_optimizer(model, cfg.optimizer)
# Step4: 初始化Runner
runner = build_runner(
cfg.runner,
default_args=dict(model=model, optimizer=optimizer)
# Step5: 注册默认Hook(注册到runner._hooks列表中)
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config,
cfg.get('momentum_config', None))
# Step6: 注册自定义Hook(注册到runner._hooks列表中)
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
for hook_cfg in cfg.custom_hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
# Step7: 开始训练流程
if cfg.resume_from:
# 恢复检查点
runner.resume(cfg.resume_from)
elif cfg.load_from:
# 加载预训练模型
runner.load_checkpoint(cfg.load_from)
# 调用run()方法, 开始迭代过程
runner.run(data_loaders, cfg.workflow)
公式train.py
ドキュメントは非常に長いですが、コア コードは実際には Pytorch でよく知られているすべての操作です。train.py プロセス全体を以下の図に示します。
- まず、受信した構成ファイルを解析し、構成ファイル内の各モジュールをインスタンス化します。
- 次に、データセット構造を使用します
data_loader
。このモデルでは、主に後続の分散トレーニングのために、カプセル化の層に MMDataParallel が使用されます。 - 次に、data_loader とオプティマイザーを使用して Runner クラス オブジェクトを初期化します
runner
。 - 登録トレーニングプロセス中に使用する必要があるフック
- 反復トレーニング用の構成ファイルで指定されたワークフローに従って、
workflow
runner.run() 関数を実行します。
以下は、runner.run() 関数の内部の概要です。
5.4.2 ランナークラス
Runner はEpochBasedRunnerとIterBasedRunnerに分かれており、その名の通り、前者は epoch 形式で処理を管理し、後者は iter 形式で処理を管理し、どちらも BaseRunner のサブクラスです。EpochBasedRunner と IterBasedRunner 自体はコンストラクターをオーバーライドせず、BaseRunner のコンストラクターを直接継承します。
class BaseRunner(metaclass=ABCMeta):
def __init__(self,
model, # [torch.nn.Module] 要运行的模型
batch_processor=None, # 该参数一般不使用
optimizer=None, # [torch.optim.Optimizer] 优化器, 可以是一个也可以是一组通过dict配置的优化器
work_dir=None, # [str] 保存检查点和Log的目录
logger=None, # [logging.Logger] 训练中使用的日志记录器
meta=None, # [dict] 一些信息, 这些信息会在logger hook中记录
max_iters=None, # [int] 训练epoch数
max_epochs=None): # [int] 训练迭代次数
BaseRunner のサブクラスは、 Runner のコアメソッドでもあるrun()
、train()
、val()
およびの 4 つのメソッドを実装する必要があります。次に、EpochBasedRunner クラスを例として、これら 4 つの関数を詳細に分析します。save_checkpoint()
run() 関数
run() は Runner クラスの呼び出し関数であり、workflow で指定されたワークフローに従って data_loaders 内のデータを処理します。現在、MMCV はトレーニングと検証の 2 つのワークフローをサポートしています。EpochBasedRunner の場合、ワークフローは[('train', 2),('val', 1)]
最初に 2 つのエポックをトレーニングし、次に 1 つのエポックを検証するように構成されています。[('train', 1)]
これはトレーニングのみを意味し、検証は行いません。IterBasedRunner の場合、[('train', 2),('val', 1)]
最初に 2 つの iter をトレーニングし、次に 1 つの iter を検証することを意味します。次に、getattr(self, mode)
self.train() 関数と self.val() 関数が異なるモードに従って呼び出されます。
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
# 如果mode='train', 则调用self.train()函数
# 如果mode='val', 则调用self.val()函数
epoch_runner = getattr(self, mode)
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
# 运行train()或val()
epoch_runner(data_loaders[i], **kwargs)
train() および val() 関数
train()
とval()
関数ループ呼び出しにより、run_iter()
エポック プロセスが完了します。関数先頭の self.model.train() と self.model.eval() は、実際に torch.nn.module.Module のメンバー関数を呼び出し、現在のモジュールをトレーニング モードまたは検証モードに設定します。異なるモード バッチノルムやドロップアウトなどのレイヤーの動作が異なります。次に、テスト処理では勾配リターンが必要ないため、val 関数にデコレータを追加します@torch.no_grad()
。
def train(self, data_loader, **kwargs):
# 将模块设置为训练模式
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
for i, data_batch in enumerate(self.data_loader):
self.run_iter(data_batch, train_mode=True, **kwargs)
self._iter += 1
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
# 将模块设置为验证模式
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
for i, data_batch in enumerate(self.data_loader):
self.run_iter(data_batch, train_mode=False)
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
self.outputs = outputs
train()
andval()
のコア機能は、run_iter()
train_mode パラメータに従ってmodel.train_step()
orを呼び出すことですmodel.val_step()
。これら 2 つの関数は、最終的に独自のモデルの関数を指しforward()
、モデルの順推論結果 (通常は損失値) を返します。Runner と独自のモデルの間には、MMDataParallel、BaseDetector、SingleStageDetector (または TwoStageDetector) の 4 つのクラスがあり、最後に独自のモデルのforward()
関数を呼び出して推論プロセスを実行します。
注意深い学生は、なぜ勾配逆伝播最適化のステップを最初から最後まで見なかったのかと疑問に思うかもしれません。MMDetection の勾配の最適化は、実装されたafter_train_iter()
フックを通じて実装され、その優先順位は ABOVE_NORMAL です。
@HOOKS.register_module()
class OptimizerHook(Hook):
def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
if self.grad_clip is not None:
grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
# Add grad norm to the logger
runner.log_buffer.update({
'grad_norm': float(grad_norm)},
runner.outputs['num_samples'])
runner.optimizer.step()
save_checkpoint() 関数
save_checkpoint() 関数は比較的単純なので説明は省略しますが、最後に torch.save を呼び出してチェックポイントを以下の形式でファイルに保存します。
checkpoint = {
'meta': dict(), # 环境信息(比如epoch_num, iter_num)
'state_dict': dict(), # 模型的state_dict()
'optimizer': dict()) # 优化器的state_dict()
}