[PyTorch] Model saving and loading

1. Save only the weights (state_dict)

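import torch

# Path of the weights file
path = "state_dict_model.pt"

# Save: only the learned parameters and buffers are written
torch.save(model.state_dict(), path)

# Load: first rebuild the same architecture
model = Network()
# Load the trained weights into the model
model.load_state_dict(torch.load(path))
# Switch to inference mode before evaluating (affects dropout and batch norm)
model.eval()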
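Below is a minimal, runnable sketch of the weights-only round trip. It assumes a tiny hypothetical `Network` with a single `nn.Linear` layer, and it uses an in-memory `io.BytesIO` buffer in place of the `.pt` file. It then checks that the restored model produces the same output as the original:

```python
import io

import torch
import torch.nn as nn


class Network(nn.Module):
    """Hypothetical stand-in for the article's Network class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
model = Network()

# Save only the learned tensors; the in-memory buffer replaces the .pt file here.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then copy the saved tensors into it.
restored = Network()
restored.load_state_dict(torch.load(buffer))
restored.eval()  # inference mode (matters for dropout/batch norm layers)

x = torch.randn(1, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)
```

Only tensors are stored, so loading requires code that can rebuild `Network` first. In exchange, the file stays independent of how the class is pickled.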

2. Save the entire model

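import torch

# Save and load the entire model (architecture and weights, via pickle)
path = "entire_model.pt"

# Save the model
torch.save(model, path)

# Load the model; the Network class must be importable at this point.
# PyTorch 2.6+ defaults to weights_only=True, which rejects full modules.
model = torch.load(path, weights_only=False)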
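Here is a minimal sketch of full-model saving. It assumes PyTorch 1.13 or later, where `torch.load` accepts `weights_only`. It uses a built-in `nn.Sequential` so that the pickled class can be found again at load time. With a custom class such as `Network`, that class must be importable wherever the file is loaded. The buffer again stands in for `entire_model.pt`:

```python
import io

import torch
import torch.nn as nn

# A built-in container, so the pickled class lives in torch.nn and is importable on load.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

buffer = io.BytesIO()
torch.save(model, buffer)  # pickles the module object itself, not just its tensors
buffer.seek(0)

# weights_only=False is needed to unpickle a full module.
# Only do this with files you trust: unpickling can execute arbitrary code.
loaded = torch.load(buffer, weights_only=False)
print(type(loaded).__name__)
```

The file stores the class by its import path. Renaming or moving the class therefore breaks loading, which is one reason the state_dict approach in section 1 is usually preferred.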

3. Save a checkpoint (to resume training)

import torch

# Save a checkpoint
path = 'model.pt'
torch.save(
    {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # last training loss value, not the loss function
    },
    path
)

# Load: rebuild the model and optimizer with the same settings first
model = Network(input_num)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.train()  # resume training; call model.eval() instead for inference
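The checkpoint round trip can be sketched end to end as follows. The sketch makes a few assumptions:

- It uses a hypothetical `make_model_and_optimizer` helper and a toy `nn.Linear(3, 1)` model.
- It uses SGD with `momentum=0.9`, whereas the original uses plain SGD. Momentum is added only so that the optimizer has internal state worth restoring.

The check at the end confirms that the momentum buffer survives the save/load cycle:

```python
import io

import torch
import torch.nn as nn


def make_model_and_optimizer(lr=0.1):
    # Hypothetical helper: architecture and optimizer settings must match on save and load.
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    return model, optimizer


torch.manual_seed(0)
model, optimizer = make_model_and_optimizer()
loss_fn = nn.MSELoss()
x, y = torch.randn(16, 3), torch.randn(16, 1)

# A few training steps so the optimizer accumulates momentum buffers.
epochs = 3
for epoch in range(epochs):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

buffer = io.BytesIO()
torch.save(
    {
        "epoch": epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),
    },
    buffer,
)
buffer.seek(0)

# Fresh objects, then restore both the weights and the optimizer state.
model2, optimizer2 = make_model_and_optimizer()
checkpoint = torch.load(buffer)
model2.load_state_dict(checkpoint["model_state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
model2.train()  # ready to continue training from start_epoch

# Compare the momentum buffer before and after the round trip.
buf_before = optimizer.state[model.weight]["momentum_buffer"]
buf_after = optimizer2.state[model2.weight]["momentum_buffer"]
print(start_epoch, torch.equal(buf_before, buf_after))
```

Restoring only the model weights would reset the momentum to zero. The first steps after resuming would then differ from an uninterrupted run, which is why the optimizer state is saved as well.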

4. Inspect the state_dict

To see what a state_dict contains, print it:

net = MyModel(3)

print(net.state_dict().items())  # (name, tensor) pairs for every parameter and buffer
print(net.state_dict())          # the same contents as an OrderedDict keyed by name
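The following sketch shows the difference between what `state_dict()` stores and what counts as a learnable parameter. The article does not show the definition of `MyModel`, so this uses a hypothetical version that adds a `BatchNorm1d` layer:

```python
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical version of the article's MyModel; in_features=3 matches MyModel(3)."""

    def __init__(self, in_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 8)
        self.bn = nn.BatchNorm1d(8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc2(self.bn(self.fc1(x)).relu())


net = MyModel(3)

# state_dict() maps names to tensors: learnable parameters plus buffers (BatchNorm statistics).
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))

# named_parameters() lists only the learnable tensors, so the buffers are missing here.
param_names = [name for name, _ in net.named_parameters()]
print(param_names)
```

The loop prints nine entries, including the buffers `bn.running_mean`, `bn.running_var` and `bn.num_batches_tracked`. By contrast, `param_names` holds only the six learnable tensors. Buffers are kept in the state_dict so that loading also restores the batch-norm statistics.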


Origin blog.csdn.net/See_Star/article/details/127559560