pytorch模型保存和加载

        定义2个测试脚本test.py和test2.py,用于测试保存和加载,models文件夹保存模型,整个测试的项目文件结构如下:

E:.
│  test.py
│  test2.py
└─ models
        dongtai.pt
        dongtai_state_dict.pt
        jingtai.pth

test.py中定义了TheModelClass这个网络结构类,此外写了模型保存和加载的代码,test2.py是想测试在没有定义模型结构的脚本中,是否可以成功加载模型。

test.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
 
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
if __name__ == "__main__":

    # 初始化模型
    model = TheModelClass()
    

    # 模型保存,方法一动态图
    torch.save(model,"models/dongtai.pt")
    # 模型保存,方法二动态图
    torch.save(model.state_dict(),"models/dongtai_state_dict.pt")

    # 模型保存,方法三静态图
    x = torch.rand(1,3,30,30)   #占位符
    trace_model = torch.jit.trace(model,x) 
    torch.jit.save(trace_model,"models/jingtai.pth")


    # 模型加载,带模型结构
    model_resume = torch.load("models/dongtai.pt")

    # 模型加载,只有权重
    weights = torch.load("models/dongtai_state_dict.pt")
    model.load_state_dict(weights)

    # 直接从静态图中恢复,无需模型结构
    model = torch.jit.load("models/jingtai.pth")
    x = torch.rand(1,3,30,30) 
    pred = model(x)
    print(pred)

经过测试,pytorch可以通过三种方法实现模型的保存和加载:

  • 动态图保存模型结构和权重
  • 动态图保存权重
  • 静态图保存权重

接下来一个个说明这三种方法需要注意的地方。

一、动态图保存模型结构和权重

# 保存
model = TheModelClass()

# 模型保存
torch.save(model,"models/dongtai.pt")

# 加载,带模型结构
model_resume = torch.load("models/dongtai.pt")

保存:首先实例化网络对象,然后通过torch.save的方式,将模型结构和权重都序列化保存下来,后缀为pt或者pth都可以,不管保存成哪种后缀,都可以解析。

加载:首先必须能访问到网络结构的类TheModelClass,然后通过torch.load的方式就可以完整的将模型结构恢复,同时加载好权重。

这里需要特别注意的点,加载模型的这个文件必须要能找到网络结构的类,不管是在哪里定义网络,都要能导入到当前读取模型的这个文件中做实例化,比如我在test2.py里面导入test.py中的网络类,就可以成功加载,否则会报找不到类的错误。

test2.py

import torch
from test import TheModelClass  # 不导入或同级下找不到会有问题

model = TheModelClass()  
model_resume = torch.load("models/dongtai.pt")
model.load_state_dict(model_resume)
model.eval()
print()

二、动态图保存权重

# 初始化模型
model = TheModelClass()

# 模型保存
torch.save(model.state_dict(),"models/dongtai_state_dict.pth")

# 模型加载,只有权重
weights = torch.load("models/dongtai_state_dict.pth")
model.load_state_dict(weights)

保存:首先实例化网络对象,然后通过torch.save的方式,只将模型权重序列化保存下来,这种方法不用保存模型结构。

加载:首先必须能访问到网络结构的类TheModelClass,并实例化,然后通过torch.load的方式就可以将模型权重反序列化取出,然后将其加载进模型对象中。

注意:必须实例化网络对象,才能加载对应的权重。

三、静态图保存权重

model = TheModelClass()

# 模型保存,方法三静态图
x = torch.rand(1,3,30,30)   #占位符
trace_model = torch.jit.trace(model,x) 
torch.jit.save(trace_model,"models/jingtai.pt")

# 直接从静态图中恢复,无需模型结构
model_ji = torch.jit.load("models/jingtai.pt")

保存:首先实例化网络对象,然后用一个随机的固定尺寸的输入,通过torch.jit.trace,将网络结构前向跑一遍,记录下网络中的节点运行路径,然后通过torch.jit.save将这个运行路径存下来,这种方法会自动记录模型中节点间的数据流动顺序,也就是间接的记录下的模型结构和每个节点的权重。不会单独再保存一个模型类。

加载:直接用torch.jit.load的方法加载模型即可,因为该模型已经记录了网络中模型节点权重和数据流动的路径,因此只要将数据输入,即可“流过”整个模型,得到最终的输出,不用单独再构造模型类的实例。

总结

目前用的最多就是只保存权重的方法(方法二),最后一种用的最少,一般部署的时候也很少用,都是转成onnx再部署。

猜你喜欢

转载自blog.csdn.net/sinat_33486980/article/details/127793348