gluoncv与mxnet的model转换

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Chunfengyanyulove/article/details/83623519

前不久,gluoncv进行了更新,发布了gluoncv0.3版本,该版本的一大创新点就是提供了一批在imagenet上精度更高的模型,对于需要使用预训练模型做迁移学习的小伙伴应该是帮助很大的,毕竟高精度可以带来一定效果的提升。

网址如下:https://gluon-cv.mxnet.io/model_zoo/classification.html

可是对于不使用gluoncv,而使用mxnet的小伙伴该如何使用新提供的预训练模型呢?下面就来帮你解答一下:

1、配置环境:

首先当然是得配置环境,需要安装mxnet1.3.0以及gluoncv0.3,如何安装这里就不详述了,可以自行百度。

2、下载gluoncv0.3的模型:

代码如下:

不同的模型可以去上面的网址中寻找,另外preprocess代表下载的模型是否包含预处理层,如果包含预处理层,则为True,由于mxnet是不包含预处理层的,所以这里为None,另外,layout代表模型的输入是channel first的。

import gluoncv
from gluoncv.utils import export_block
net = gluoncv.model_zoo.get_model('DenseNet169', pretrained=True)
export_block('gluoncv0.3-DenseNet169', net, preprocess=None, layout='CHW')

这样你就下载下来了这个模型,就可以使用了

注:这个模型的输入数据需要减均值,除方差,不要忘记了

--------------------------- 华丽的分割线-------------------------------

另一个问题:

第二个问题,我们发现下载下来的模型已经可以使用了,但是细心的你可能会发现,虽然模型是一样的,但是模型的layer层的名字与mxnet提供的名字不一致,这样如果我们在模型上做更多的操作就会出现问题了,那么我们如何将两个模型的layer层的名字也变为一致呢?

思想: 虽然模型的layer name不一样,但是将json文件读取进来后,你会发现layer的顺序是一样的,这样就可以杜宇json文件,然后按照json中layer的顺序一一对应,这样就可以利用这个顺序对模型进行赋值,得到新模型了

代码如下,也很好理解。

import mxnet as mx
import json
import argparse
import os


def main(args):

    ## prepare layer name

    aim_model_arg_name = []
    aim_model_aux_name = []

    src_model_arg_name = []
    src_model_aux_name = []

    aim_json_path = os.path.join(args.model_path, args.aim_model_name + '-symbol.json')
    with open(aim_json_path,'r') as load_f:
        load_dict = json.load(load_f)

    for layer in load_dict['nodes']:
        layername = layer['name']
        if layername[-4:] == 'mean' or layername[-3:] == 'var':
            aim_model_aux_name.append(layername)
        else:
            aim_model_arg_name.append(layername)

    print('aim model param read finished')

    src_json_path = os.path.join(args.model_path, args.src_model_name + '-symbol.json')
    with open('../../src/engine/engine/img_cls_common_0907_stick_small_image/model/gluoncv0.3-resnet-101-symbol.json','r') as load_f:
        load_dict = json.load(load_f)

    for layer in load_dict['nodes']:
        layername = layer['name']
        if layername[-4:] == 'mean' or layername[-3:] == 'var':
            src_model_aux_name.append(layername)
        else:
            src_model_arg_name.append(layername)

        print('src model param read finished')

## start copy

    aim_sym, aim_arg, aim_aux = mx.model.load_checkpoint('../../src/engine/engine/img_cls_common_0907_stick_small_image/model/resnet-101',17)
    src_sym, src_arg, src_aux = mx.model.load_checkpoint('../../src/engine/engine/img_cls_common_0907_stick_small_image/model/gluoncv0.3-resnet-101',17)


    for i in range(len(aim_model_arg_name)):
        if aim_arg.has_key(aim_model_arg_name[i]):
            aim_arg[aim_model_arg_name[i]] = src_arg[src_model_arg_name[i]]

    for i in range(len(aim_model_aux_name)):
        if aim_aux.has_key(aim_model_aux_name[i]):
            aim_aux[aim_model_aux_name[i]] = src_aux[src_model_aux_name[i]]


    model = mx.mod.Module(aim_sym, context=mx.gpu(0))
    model.bind(data_shapes=[('data', (1, 3, 224, 224))], for_training=False)  # bangding
    model.set_params(aim_arg, aim_aux)

    model.save_checkpoint('../../src/engine/engine/img_cls_common_0907_stick_small_image/model/new-resnet',0)

    print('{} model change to {} finished, result model name:{}'.format(args.src_model_name, args.aim_model_name, args.result_model_name))


if __name__ is '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, default=None)
    parser.add_argument('--aim-model-name', type=str, default='resnet-101')
    parser.add_argument('--aim-model-index', type=int, default=0)
    parser.add_argument('--src-model-name', type=str, default=None)
    parser.add_argument('--src-model-index', type=int, default=0)
    parser.add_argument('--result-path', type=str, default=None)
    parser.add_argumnet('--result-model-name', type=str,default='new-model')
    args = parser.parse_args()

    main(args)


如上,是本人的一点实践,如果有更好的方法,欢迎指正

猜你喜欢

转载自blog.csdn.net/Chunfengyanyulove/article/details/83623519