pytorch storage model

Only store model parameters

import torch
import torch.nn as nn

# 定义一个简单的模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 创建一个模型实例
model = MyModel()
path = 'output/model'

# 保存模型
torch.save(model.state_dict(), path+'.pth')

# 创建一个新的模型对象(与保存模型的模型结构相同)
model_1 = MyModel()
# 加载模型参数
model_1.load_state_dict(torch.load(path+'.pth'))

# 判断两个模型参数是否相同
# 判断模型参数是否相同
params_equal = True
for (a_name, a_param), (b_name, b_param) in zip(model.state_dict().items(), model_1.state_dict().items()):
    if a_name != b_name or not torch.equal(a_param, b_param):
        params_equal = False
        break

if params_equal:
    print("模型参数相同")
else:
    print("模型参数不相同")

operation result:

insert image description here

save the whole model

import torch
import torch.nn as nn

# 定义一个简单的模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 创建一个模型实例
model = MyModel()
path = 'output/model'

# 保存模型
torch.save(model, path+'.pth')

# 加载模型
model_1 = torch.load(path+'.pth')

# 判断两个模型参数是否相同
# 判断模型参数是否相同
params_equal = True
for (a_name, a_param), (b_name, b_param) in zip(model.state_dict().items(), model_1.state_dict().items()):
    if a_name != b_name or not torch.equal(a_param, b_param):
        params_equal = False
        break

if params_equal:
    print("模型参数相同")
else:
    print("模型参数不相同")

Guess you like

Origin blog.csdn.net/weixin_42173136/article/details/131558861