PyTorch模型的保存与提取

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/weixin_44613063/article/details/95379739

训练好一个模型后,最好将它直接保存下来,等以后需要的时候就可以直接提取调用,十分方便也节约了很多时间。

三个核心功能:

  1. torch.save:将序列化的对象保存到 disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。
  2. torch.load:使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存。
  3. torch.nn.Module.load_state_dict:使用反序列化状态字典加载 models 参数字典。

保存模型

 torch.save(model, PATH)	# 保存整个网络
 torch.save(model.state_dict(), PATH)	 # 只保存网络中的参数 (速度快, 占内存少)

假如训练出的模型取名为 net1

保存的例子:

torch.save(net1, 'net.pkl')
torch.save(net1.state_dict(), 'net_params.pkl')

提取模型

model = torch.load(PATH)	# 提取整个网络

model = ModelClass(*args, **kwargs)		# ModelClass 为已定义的网络
model.load_state_dict(torch.load(PATH))		# 提取模型参数

提取整个网络的例子:

net2 = torch.load('net.pkl')
prediction = net2(x)

仅提取参数的例子:

net3 = torch.nn.Sequential(
	torch.nn.Linear(1, 10),
	torch.nn.ReLU(),
	torch.nn.Linear(10, 1)
)

net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)

猜你喜欢

转载自blog.csdn.net/weixin_44613063/article/details/95379739