【pytorch】保存模型、加载模型

1、torch 的 save 和 load

我们可以直接使用 save 函数 和 load函数 进行存储和读取。

  • save 使用 Python 的 pickle 实用程序将对象进行序列化,然后将序列化的对象保存到disk。save可以保存各种对象,包括模型、张量和字典等。
  • load 使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存
import torch

x = [1, 2]
y = {
    
    'name':'xiaoming', 'age':16}
z = (x, y)

torch.save(z, 'z.pt')

z_new = torch.load('z.pt')
print(z_new)   # ([1, 2], {'name': 'xiaoming', 'age': 16})

2、state_dict

1)net.state_dict()

PyTorch中,Module 的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters() 访问)。state_dict 是一个从参数名称隐射到参数 Tesnor 的有序字典对象。
注意,只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict 中的条目。

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net.state_dict())

# OrderedDict([('hidden.weight', tensor([[-0.2360, -0.3193, -0.2618],[ 0.1759, -0.0888,  0.2635]])), 
#              ('hidden.bias', tensor([ 0.2161, -0.3944])), 
#              ('output.weight', tensor([[-0.5358, -0.2140]])), 
#              ('output.bias', tensor([0.6262]))])

2)optimizer.state_dict()

优化器(optim) 也有一个 state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

import torch
import torch.nn as nn


net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print(optimizer.state_dict())
# {'state': {}, 
#  'param_groups': [{'lr': 0.001, 
#                    'momentum': 0.9, 
#                    'dampening': 0, 
#                    'weight_decay': 0, 
#                    'nesterov': False, 
#                    'maximize': False, 
#                    'foreach': None, 
#                    'params': [0, 1, 2, 3]}]}


3、保存模型 和 加载模型

1)仅保存和加载模型参数(state_dict)

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# 保存模型参数
torch.save(net.state_dict(), 'model_weight.pt')   # 推荐的文件后缀名是pt或pth

# 下载模型参数
model_weight = torch.load('model_weight.pt')
print(model_weight)
# OrderedDict([('0.weight', tensor([[-0.3865, -0.4623,  0.1212],[-0.2480,  0.3840,  0.1916]])), 
#              ('0.bias', tensor([ 0.0698, -0.0641])), 
#              ('2.weight', tensor([[-0.1499, -0.2895]])), 
#              ('2.bias', tensor([0.2585]))])

# 下载模型参数 并放到模型中
net.load_state_dict(torch.load('model_weight.pt'))

2)保存 和 加载 整个模型

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# 保存整个模型
torch.save(net, 'model.pt')

# 下载模型参数 并放到模型中
net_new = torch.load('model.pt')
print(net_new)
# Sequential(
#   (0): Linear(in_features=3, out_features=2, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=2, out_features=1, bias=True)
# )

猜你喜欢

转载自blog.csdn.net/weixin_37804469/article/details/129139443