[PyTorch] Saving and loading models

1. torch.save and torch.load

We can use the torch.save and torch.load functions directly to store and read objects.

  • torch.save uses Python's pickle utility to serialize an object and then writes the serialized object to disk. It can save many kinds of objects, including models, tensors, and dictionaries.
  • torch.load uses pickle's unpickling facilities to deserialize a pickled object file back into memory.
import torch

x = [1, 2]
y = {'name': 'xiaoming', 'age': 16}
z = (x, y)

torch.save(z, 'z.pt')

z_new = torch.load('z.pt')
print(z_new)   # ([1, 2], {'name': 'xiaoming', 'age': 16})
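torch.save also accepts a file-like object instead of a path, which is handy for keeping a serialized copy in memory. The sketch below round-trips a dictionary through an io.BytesIO buffer; the dictionary `checkpoint` and its keys are hypothetical, chosen only for illustration:

```python
import io

import torch

# Hypothetical object to serialize: a dict mixing a tensor and plain Python values
checkpoint = {"weights": torch.arange(6.0).reshape(2, 3), "step": 10, "tags": ["demo"]}

buffer = io.BytesIO()
torch.save(checkpoint, buffer)  # serialize into memory instead of a file on disk

buffer.seek(0)  # rewind so torch.load reads from the beginning
restored = torch.load(buffer)

print(restored["step"], restored["tags"])  # 10 ['demo']
print(torch.equal(restored["weights"], checkpoint["weights"]))  # True
```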

2. state_dict

1) net.state_dict()

In PyTorch, the learnable parameters of a Module model (i.e., its weights and biases) are contained in its parameters, which can be accessed with model.parameters(). A state_dict is an ordered dictionary (OrderedDict) that maps each parameter name to its parameter Tensor.
Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as BatchNorm's running_mean) have entries in the state_dict; in the example below, the ReLU layer has none.

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net.state_dict())

# Example output (values differ between runs because the weights are randomly initialized):
# OrderedDict([('hidden.weight', tensor([[-0.2360, -0.3193, -0.2618],[ 0.1759, -0.0888,  0.2635]])), 
#              ('hidden.bias', tensor([ 0.2161, -0.3944])), 
#              ('output.weight', tensor([[-0.5358, -0.2140]])), 
#              ('output.bias', tensor([0.6262]))])
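The sketch below uses an illustrative network (not from the original) to compare state_dict() keys against named_parameters(). The ReLU layer contributes no entries, while BatchNorm1d adds registered buffers that are stored in the state_dict even though the optimizer does not train them:

```python
import torch.nn as nn

# Illustrative network: Linear and BatchNorm1d have parameters, ReLU has none
net = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2), nn.ReLU())

param_names = [name for name, _ in net.named_parameters()]
state_names = list(net.state_dict().keys())

print(param_names)
# ['0.weight', '0.bias', '1.weight', '1.bias']

# Entries in the state_dict that are not learnable parameters (BatchNorm buffers)
print([n for n in state_names if n not in param_names])
# ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```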

2) optimizer.state_dict()

Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer's state as well as the hyperparameters used.

import torch
import torch.nn as nn


net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print(optimizer.state_dict())
# Example output (the exact keys in param_groups vary between PyTorch versions):
# {'state': {}, 
#  'param_groups': [{'lr': 0.001, 
#                    'momentum': 0.9, 
#                    'dampening': 0, 
#                    'weight_decay': 0, 
#                    'nesterov': False, 
#                    'maximize': False, 
#                    'foreach': None, 
#                    'params': [0, 1, 2, 3]}]}
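The 'state' entry above is empty because the optimizer has not taken a step yet. A minimal sketch, using a random input batch made up for illustration, shows that after one SGD step with momentum the state holds a momentum_buffer for each parameter, keyed by the same integer indices that appear in 'params':

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# One training step on a random, made-up batch of 4 samples
x = torch.randn(4, 3)
loss = net(x).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

state = optimizer.state_dict()["state"]
print(sorted(state.keys()))               # [0, 1, 2, 3]
print("momentum_buffer" in state[0])      # True
print(state[0]["momentum_buffer"].shape)  # torch.Size([2, 3])
```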


3. Saving and loading models

1) Save and load only the model parameters (state_dict)

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# Save the model parameters
torch.save(net.state_dict(), 'model_weight.pt')   # the recommended file extension is .pt or .pth

# Load the model parameters
model_weight = torch.load('model_weight.pt')
print(model_weight)
# OrderedDict([('0.weight', tensor([[-0.3865, -0.4623,  0.1212],[-0.2480,  0.3840,  0.1916]])), 
#              ('0.bias', tensor([ 0.0698, -0.0641])), 
#              ('2.weight', tensor([[-0.1499, -0.2895]])), 
#              ('2.bias', tensor([0.2585]))])

# Load the model parameters and put them into the model
net.load_state_dict(torch.load('model_weight.pt'))
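A common extension of this pattern is a training checkpoint that bundles the model's state_dict with the optimizer's state_dict so training can resume later. The sketch below assumes a hypothetical file name checkpoint.pt, hypothetical dictionary keys "model", "optimizer" and "epoch", and a hypothetical helper build(); map_location="cpu" places the stored tensors on the CPU, which helps when a GPU-trained checkpoint is opened on a CPU-only machine:

```python
import torch
import torch.nn as nn


def build():
    # Hypothetical helper: the same architecture must exist before load_state_dict
    net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    return net, optimizer


net, optimizer = build()

# Save a checkpoint (file name and keys are illustrative choices)
torch.save(
    {"model": net.state_dict(), "optimizer": optimizer.state_dict(), "epoch": 5},
    "checkpoint.pt",
)

# Later: rebuild fresh objects, then restore their states from the checkpoint
net_new, optimizer_new = build()
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
net_new.load_state_dict(checkpoint["model"])
optimizer_new.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1

net_new.eval()  # switch to evaluation mode before running inference
print(start_epoch)  # 6
```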

2) Save and load the entire model

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# Save the entire model
torch.save(net, 'model.pt')

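# Load the entire model (architecture and parameters)
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle
# a full nn.Module; weights_only=False opts out and should only be used for trusted files
net_new = torch.load('model.pt', weights_only=False)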
print(net_new)
# Sequential(
#   (0): Linear(in_features=3, out_features=2, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=2, out_features=1, bias=True)
# )
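To check that the whole-model round trip preserves behavior, the sketch below saves net into an in-memory io.BytesIO buffer instead of model.pt (a choice made only for illustration), reloads it, and compares outputs on a random input. It assumes a PyTorch release that accepts the weights_only argument (added in 1.13). Loading a whole model also requires the model's class to be importable at load time, which is trivially true for nn.Sequential:

```python
import io

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

buffer = io.BytesIO()
torch.save(net, buffer)  # pickles the module object itself, not just its state_dict
buffer.seek(0)

# Full unpickling is required for a module object; only do this for trusted data
net_new = torch.load(buffer, weights_only=False)
net_new.eval()

x = torch.randn(4, 3)
with torch.no_grad():
    same = torch.allclose(net(x), net_new(x))
print(type(net_new).__name__, same)  # Sequential True
```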


Origin blog.csdn.net/weixin_37804469/article/details/129139443