【mmcls】mmdet中使用mmcls的网络及预训练模型

        mmcls现在叫mmpretrain,以前叫mmclassification,这里为了统一称为mmcls。在基于MM框架的下游任务,例如检测(mmdetection)中可以使用mmcls中的backbone进行特征提取,但这就需要知道网络的参数以及输出特征的维度。本文简单介绍了在mmdetection中使用mmcls中backbone的方法。mmdetection中需要配置backbone、模型权重及neck的特征维度等信息。

1 查找mmcls预训练模型

        查找mmcls支持的网络的方法有多种:

  1. 在mmpretrain的README中;
  2. 在modelzoo种查找模型库统计 — MMClassification 1.0.0rc6 文档 (mmpretrain.readthedocs.io)
  3. 直接看repo的configs目录下的列表

2 获取网络参数(配置)及预训练权重

        找到网络后还需要找到网络参数及预训练权重。以replknet为例,获取网络参数可以直接看mmpretrain/configs/replknet中的配置文件,例如replknet-31B_32xb64_in1k.py,但配置文件可能并没有直接写模型配置信息,而是依赖其他配置文件,如下图中的replknet-31B_in1k.py

         继续找到上述配置文件,可以看到网络配置:

        预训练权重可以在mmpretrain/configs/replknet下的README中找到,例如:

 https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth

        预训练权重也可以在modelzoo中查找:

3 获取特征输出维度

        首先在modelzoo中查到已有模型的名称,然后使用mmcls.get_model获取模型,输出指定层的特征维度。

import torch
from mmcls import get_model, inference_model

inputs = torch.rand(16, 3, 224, 224)

# 构建模型
model_name = 'replknet-31B_in21k-pre_3rdparty_in1k'
model = get_model(model_name, pretrained=False, backbone=dict(out_indices=(0, 1, 2, 3)))
# model = get_model(model_name, pretrained=False, backbone=dict(out_scales=(0, 1, 2, 3)))  # mvitv2

feats = model.extract_feat(inputs)
for feat in feats:
    print(feat.shape)

        可以看到输出为 [128, 256, 512, 1024]:

torch.Size([16, 128])
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])

4 mmdetection中使用

        在mmdetection中修改配置文件中backbone,预训练权重和neck中的in_channels等信息。同时应该注意网络的优化器配置的参数。

checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth'  # noqa

model = dict(
    backbone=dict(
        _delete_=True,
        type='mmcls.RepLKNet',
        arch='31B',
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint_file,
            prefix='backbone.')),
    neck=dict(
        _delete_=True,
        type='mmdet.FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5))

猜你喜欢

转载自blog.csdn.net/dou3516/article/details/131224282