版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
训练好一个模型后,最好将它直接保存下来,等以后需要的时候就可以直接提取调用,十分方便也节约了很多时间。
三个核心功能:
torch.save
:将序列化的对象保存到 disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。torch.load
:使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存。torch.nn.Module.load_state_dict
:使用反序列化状态字典加载 models 参数字典。
保存模型
torch.save(model, PATH) # 保存整个网络
torch.save(model.state_dict(), PATH) # 只保存网络中的参数 (速度快, 占内存少)
假如训练出的模型取名为 net1
保存的例子:
torch.save(net1, 'net.pkl')
torch.save(net1.state_dict(), 'net_params.pkl')
提取模型
model = torch.load(PATH) # 提取整个网络
model = ModelClass(*args, **kwargs) # ModelClass 为已定义的网络
model.load_state_dict(torch.load(PATH)) # 提取模型参数
提取整个网络的例子:
net2 = torch.load('net.pkl')
prediction = net2(x)
仅提取参数的例子:
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)