MXNet学习 (1) :加载预训练模型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wwwhp/article/details/84556909
  • 首先在MXNet的model zoo下载对应的模型描述文件以及模型参数文件:
    • vgg16:对应vgg16.json vgg16-0000.params
    • resnet50:对应resnet50.json resnet50-0000.params
  • 加载网络结构设置网络运行runtime context:
import mxnet as mx
import numpy as np
from collections import namedtuple

net_name = 'vgg16'
img_name = 'dog.jpg'

# imagenet 图像预处理
def load_image(img_name):
    img = mx.image.imread(img_name)
    img = mx.image.imresize(img, 224, 224)
    img = img.transpose((2,0,1)) # hwc->chw
    img = img.expand_dims(axis=0) # chw->1chw
    img = img.astype('float32')
    return img

# 加载 mxnet symbol
sym, arg, aux = mx.model.load_checkpoint(net_name, 0) #net_name代表加载网络name,第二个参数代表Epoch num
# 设置 runtime context
ctx = mx.cpu() || ctx = mx.gpu()
  • 构造module用于执行symbol得到结果
mod = mx.mod.Module(symbol=sym, context=ctx)
mod.bind(for_training=False, data_shape[('data', (1, 3, 224, 224))]) # 为输入数据分配内存
mod.set_params(arg, aux) # 加载模型参数
Batch = namedtuple('Batch', ['data'])
img = load_image(img_name)
mod.forward(Batch([img])) # 做简单的inference
prob = mod.get_outputs()[0].asnumpy
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1] # 得到分类网络分类置信度的从大到小的结果
  • 关乎symbol和module的一些基本属性
# 查看json每一个op的属性:kernel size、padding、stride等
sym.attr_dict() # 返回一个字典,根据key获取对应op的属性
# 查看网络的输出name
sym.list_outputs()
# 查看网络所有的输入节点name
sym.list_arguments()
# 查看网络所有内部节点
sym.get_internals()
# 获取网络的参数节点name
mod.get_params()[0]
# 获取网络的中间结果 fc7 output
all_layers = sym.get_internals()
sym = all_layers['fc7_output']
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) # 然后做一次inference就能获取fc7 output
  • 遗留问题mod设置label_names=None时候,会提示一个warning目前不清楚怎么解决
//由于json里面有一个输入节点为softmax_label导致在做inference的时候总是会提示label_shapes的warninig,但是实际上在做inference的时候是不需要输入softmax_label的
Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])

猜你喜欢

转载自blog.csdn.net/wwwhp/article/details/84556909