gluon训练出的模型转成mx.mode.Module可用的symbol

6. 保存成Symbol格式的网络和参数(重点)

要注意保存网络参数的时候,需要net.collect_params().save()这样保存,而不是net.save_params()保存
最新版的mxnet已经有可以导出到symbol格式下的接口了。需要mxnet版本在20171015以上
下面示例代码也已经改成新版的保存,加载方式

#新版本的保存方式
net.export('Gluon_FashionMNIST')

7. 使用Symbol加载网络并绑定

symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 1, 28, 28))])
mod.load_params('Gluon_FashionMNIST-0000.params')
Batch = namedtuple('Batch', ['data'])

8. 预测试试看效果

img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1)

plt.imshow(img.reshape((28, 28)).asnumpy())
plt.axis('off')
plt.show()
print('predicted labels:',get_text_labels(predicted_labels.asnumpy()))

print('true labels:',get_text_labels([label]))

猜你喜欢

转载自blog.csdn.net/insanegtp/article/details/79990356