6. 保存成Symbol格式的网络和参数(重点)
要注意保存网络参数的时候,需要net.collect_params().save()
这样保存,而不是net.save_params()
保存
最新版的mxnet已经有可以导出到symbol格式下的接口了。需要mxnet版本在20171015以上
下面示例代码也已经改成新版的保存,加载方式
#新版本的保存方式
net.export('Gluon_FashionMNIST')
7. 使用Symbol加载网络并绑定
symnet = mx.symbol.load('Gluon_FashionMNIST-symbol.json')
mod = mx.mod.Module(symbol=symnet, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 1, 28, 28))])
mod.load_params('Gluon_FashionMNIST-0000.params')
Batch = namedtuple('Batch', ['data'])
8. 预测试试看效果
img,label = fashion_test[random.randint(0, 60000)]
data = img.transpose([2,0,1])
data = data.reshape([1,1,28,28])
mod.forward(Batch([data]))
out = mod.get_outputs()
prob = out[0]
predicted_labels = prob.argmax(axis=1)
plt.imshow(img.reshape((28, 28)).asnumpy())
plt.axis('off')
plt.show()
print('predicted labels:',get_text_labels(predicted_labels.asnumpy()))
print('true labels:',get_text_labels([label]))