MXNet官方文档中文版教程(8):使用预训练模型预测

文档英文原版参见Predict with pre-trained models

本教程介绍如何使用预训练模型识别图像中的对象,以及如何进行特征提取。

前提条件

为了完成以下教程,我们需要:

  • MXNet:安装教程
    -Python Requests, Matplotlib and Jupyter Notebook.
$ pip install requests matplotlib jupyter opencv-python

载入

我们首先下载一个预训练的,在完整的ImageNet数据集上训练的152层的ResNet模型,该数据集拥有超过1000万张图像和1万个类别。预训练模型包含两部分,包含模型定义的json文件和包含参数的二进制文件。 此外,还可能有一个用于标签的文本文件。

import mxnet as mx
path='http://data.mxnet.io/models/imagenet-11k/'
[mx.test_utils.download(path+'resnet-152/resnet-152-symbol.json'),
 mx.test_utils.download(path+'resnet-152/resnet-152-0000.params'),
 mx.test_utils.download(path+'synset.txt')]

接下来,我们载入下载的模型。注意:如果GPU可用,我们可以用mx.gpu() 替换所有出现的mx.cpu(),以加速计算。

sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], 
         label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
with open('synset.txt', 'r') as f:
    labels = [l.rstrip() for l in f]

预测

我们首先定义帮助函数来下载图像并进行预测:

%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np
# define a simple data batch
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

def get_image(url, show=False):
    # download and show the image
    fname = mx.test_utils.download(url)
    img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)
    if img is None:
         return None
    if show:
         plt.imshow(img)
         plt.axis('off')
    # convert into format (batch, RGB, width, height)
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    img = img[np.newaxis, :]
    return img

def predict(url):
    img = get_image(url, show=True)
    # compute the predict probabilities
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
    a = np.argsort(prob)[::-1]
    for i in a[0:5]:
        print('probability=%f, class=%s' %(prob[i], labels[i]))

现在,我们可以使用任何可下载的URL进行预测:

predict('http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
predict('http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg')

特征提取

通过特征提取,意思是通过内层的输出而不是最后的softmax层来表示输入图像。可以将这些输出视为原始输入图像的特征,然后可以用于其他方面(如目标检测)。

我们可以使用get_internals 方法从Symbol获取所有中间层。

# list the last 10 layers
all_layers = sym.get_internals()
all_layers.list_outputs()[-10:]

经常使用的特征提取层是最后的全连接层之前的那层。对于ResNet,还有Inception,它是一个名为flatten0 的平展层,该层将4维卷积层输出重构为2维的全连接层。以下代码提取了一个新符号,它用来输出平展层并创建模型。

fe_sym = all_layers['flatten0_output']
fe_mod = mx.mod.Module(symbol=fe_sym, context=mx.cpu(), label_names=None)
fe_mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
fe_mod.set_params(arg_params, aux_params)

我们现在可以调用forward 来获取特征:

img = get_image('http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
fe_mod.forward(Batch([mx.nd.array(img)]))
features = fe_mod.get_outputs()[0].asnumpy()
print(features)
assert features.shape == (1, 2048)

猜你喜欢

转载自blog.csdn.net/qq_36165459/article/details/78394434