MXNET深度学习框架-13-读写存模型

        如果一个深度网络(几十层的网络)在训练时出现突然断电,内存溢出,电脑蓝屏等情况,是不是会很抓狂?所以模型存储就变得很重要。本章我们在mxnet下学习如何进行模型存储与读写。

1、读写NDArray
        (1) NDArray是mxnet中的一个科学计算库,下面我们来实现以下怎么存储NDArray的参数:

x = nd.ones(shape=3)
y = nd.ones(shape=2)
filename1 = "13-模型存储/test1_array.params"
nd.save(filename1, [x, y])

结果:
在这里插入图片描述
可以看到,参数确实已经被保存在了该文件夹下。
        不仅仅是NDArray,dict也是一样的:

mydict={"x":x,"y":y}
filename2 = "13-模型存储/test2_dict.params"
nd.save(filename2, mydict)

运行结果:
在这里插入图片描述
        (2) 读模型

a, b = nd.load(filename1)  # 模型保存路径和名字
c=nd.load(filename2)
print(a, b,c)

结果:
在这里插入图片描述
        (3) 神经网络模型存储与读写

# 随便定义一个网络
def get_net():
    net=gn.nn.Sequential()  #nn.block
    with net.name_scope():
        net.add(gn.nn.Dense(10,activation="relu"))
        net.add(gn.nn.Dense(2))
    return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01)  # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
net.save_parameters(net_filename) # 存模型参数

结果:
在这里插入图片描述
在这里插入图片描述
下面把模型读出来:

net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))

结果:
在这里插入图片描述
从上图可知,net1和net2的结果是一样的。

附上所有源码:

import mxnet.ndarray as nd
import mxnet.gluon as gn


# '''---模型存储---'''
# x = nd.ones(shape=3)
# y = nd.ones(shape=2)
# filename1 = "13-模型存储/test1_array.params"
# nd.save(filename1, [x, y])
# # 不仅仅是NDArray,dict也是一样的
# mydict={"x":x,"y":y}
# filename2 = "13-模型存储/test2_dict.params"
# nd.save(filename2, mydict)
#
# '''---读取模型---'''
# a, b = nd.load(filename1)  # 模型保存路径和名字
# c=nd.load(filename2)
# print(a, b,c)

'''---读写gluon的参数---'''
# 随便定义一个网络
def get_net():
    net=gn.nn.Sequential()  #nn.block
    with net.name_scope():
        net.add(gn.nn.Dense(10,activation="relu"))
        net.add(gn.nn.Dense(2))
    return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01)  # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
# net.save_parameters(net_filename) # 存模型参数
# 下面把模型读出来
net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))
发布了88 篇原创文章 · 获赞 39 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/Daker_Huang/article/details/105557748