#coding:utf-8
'''
读取和存储
'''
import os

import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
# Save a list of NDArrays to a single file, then load it back.
x = nd.ones(3)
y = nd.zeros(4)
filename = 'datas/test1.params'
# nd.save writes to the path exactly as given; create the parent directory
# first so the script also runs from a fresh checkout. The explicit isdir
# check (instead of makedirs(exist_ok=True)) keeps this Python 2 compatible.
params_dir = os.path.dirname(filename)
if params_dir and not os.path.isdir(params_dir):
    os.makedirs(params_dir)
nd.save(filename, [x, y])
# A list was saved, so nd.load returns a list in the same order.
a, b = nd.load(filename)
print(a, b)
# 读写Gluon模型参数
def get_net():
    """Build a small two-layer MLP.

    The network is a ``Sequential`` container holding a 10-unit ``Dense``
    layer with ReLU activation followed by a 2-unit ``Dense`` output layer.
    Parameters are not initialized here: callers either call
    ``initialize()`` or load previously saved parameters into the network.

    Returns:
        mxnet.gluon.nn.Sequential: the un-initialized network.
    """
    net = nn.Sequential()
    # Child layers are added inside the container's name scope so their
    # parameter names are derived from the container's prefix.
    with net.name_scope():
        net.add(nn.Dense(10, activation='relu'))
        net.add(nn.Dense(2))
    return net
# --- Saving and loading Gluon model parameters ---
filename = 'datas/mlp.params'

# Build the MLP, initialize it and run one forward pass on a random
# (2, 10) batch; the printed output serves as the reference result.
net = get_net()
net.initialize()
x = nd.random.uniform(shape=(2, 10))
reference_out = net(x)
print(reference_out)

# Write the parameters of `net` to `filename`.
# NOTE(review): later MXNet 1.x releases deprecate save_params/load_params
# in favour of save_parameters/load_parameters; confirm which API the
# installed version expects.
net.save_params(filename)

# Construct a second network with the same architecture, load the saved
# parameters into it on the CPU and run it on the same batch. Its printed
# output is expected to match the reference result above.
net2 = get_net()
net2.load_params(filename, ctx=mx.cpu())
restored_out = net2(x)
print(restored_out)
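# Source article: "MXNet动手学深度学习笔记:模型参数保存与读取"
# (MXNet "Dive into Deep Learning" notes: saving and loading model parameters),
# reprinted from my.oschina.net/wujux/blog/1809879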