pytorch保存、加载模型, 并将网络模型.pt保存为ONNIX

pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。

简单说,原先的net在保存之前,要eval一下,load之后的net也要eval一下,把所有参数freeze掉。才保证两个net完全相同(输入相同tensor得到完全一致的结果),具体原因参见:pytorch模型的保存与加载注意事项

pytorch有两种模型保存方式:

一、保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net。

二、只保存神经网络的训练模型参数,save的对象是net.state_dict()。
对应两种保存模型的方式,pytorch也有两种加载模型的方式。对应第一种保存方式,加载模型时通过torch.load(’.pth’)直接初始化新的神经网络对象;对应第二种保存方式,需要首先导入对应的网络,再通过net.load_state_dict(torch.load(’.pth’))完成模型参数的加载。

在网络比较大的时候,第一种方法会花费较多的时间。

Pytorch两种模型保存方式
1,只保存模型参数

# 保存
torch.save(model.state_dict(), '\parameter.pt')
# 加载
model = TheModelClass(...)
model.load_state_dict(torch.load('\parameter.pt'))
model.eval()

在保存模型进行推理时,只需保存经过训练的模型的学习参数即可。使用 torch.save() 函数 保存模型的 state_dict 将为以后恢复模型提供最大的灵活性,这就是为什么推荐使用它来保存模型。

一个常见的PyTorch约定是使用 .pt 或 .pth 文件扩展名保存模型。
请记住,在运行推理之前,必须调用 model.eval() 来将 dropout 和 batch normalization layers 设置为评估模式。如果不这样做,就会产生不一致的推理结果。

2, 保存完整模型

保存:
torch.save(model, PATH)

加载:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

这个保存/加载过程使用最直观的语法,涉及的代码最少。以这种方式保存模型将使用Python的 pickle 模块保存整个model。 这种方法的缺点是序列化数据被绑定到保存模型时使用的特定类和精确的目录结构。 原因是pickle没有保存模型类本身。相反,它保存到包含类的文件的路径,该类在加载时使用。 正因为如此,当您在其他项目中使用时或在重构之后,您的代码可能以各种方式中断。

.pt转为ONNX
原始的代码如下:

import torch
if __name__ == "__main__":
    # 保存为onnx格式
    model = torch.load("E:/age_gender_model.pt")
    # set the model to inference mode
    model.eval()
    model.cpu()
    dummy_input1 = torch.randn(1, 3, 64, 64)
    torch.onnx.export(model, (dummy_input1), "age_gender_model.onnx", verbose=True)

但是运行会报错:
AttributeError: Can’t get attribute ‘xxxNet’ on <module ‘main’ from

经查询,必须引入xxxNet模型的定义文件。添加以下代码即可解决问题。

from age_gender_cnn import MyMulitpleTaskNet

或者引入全部:

from age_gender_cnn import  *

需要注意的是在导出为ONNX模型之前,也必须要执行eval, 如下图所示
在这里插入图片描述

另外官方的推荐方法(我还没来得及看):
https://pytorch.org/docs/master/notes/serialization.html

另外,推荐一篇博客:
PyTorch学习:加载模型和参数

Guess you like

Origin blog.csdn.net/thequitesunshine007/article/details/119306517