Pytorch之——模型保存与加载


一、序列化与反序列化

在这里插入图片描述
序列化:将内存中的数据转成二进制的数存储在硬盘中
反序列化:相反的过程

二、pytorch中的模型保存与加载的两种方式

1.模型保存——torch.save()

主要参数:
obj:对象
f:输出路径

保存的两种方式
在这里插入图片描述

#构建路径
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

2.模型加载——torch.load()

主要参数:
f:文件路径(与torch.save中的f相同)
map_location:指定存放位置,cpu or gpu

#load net
path_model = "./model.pkl"
net_load = torch.load(path_model)
#load state_dict
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

这里要注意把参数加载过来之后要把加载进来的参数赋值给新的模型

#这里以保存LeNet 然后赋值给新的LeNet为例
net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_load)

三、断点续训练——checkpoint(适用于机器忽然停电后继续训练模型时)

在这里插入图片描述
对于一个模型,我们要知道保存的地方应该是在模型和优化器中的参数

1.模型保存

#这里以每5个epoch保存一次参数为例讲解

checkpoint_interval=5
 if (epoch+1) % checkpoint_interval == 0:

     checkpoint = {
    
    "model_state_dict": net.state_dict(),  #保存模型参数
                   "optimizer_state_dict": optimizer.state_dict(),  #保存优化器参数
                   "epoch": epoch}     #保存第几次epoch
     path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)   #保存模型参数路径
     torch.save(checkpoint, path_checkpoint)

2.模型加载

#在定义好损失函数时,模型接着训练前,加载参数并赋值让模型继续训练
path_checkpoint = "./checkpoint_4_epoch.pkl"       
checkpoint = torch.load(path_checkpoint)             #加载参数

net.load_state_dict(checkpoint['model_state_dict'])  #加载模型参数

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  #加载优化器参数 

start_epoch = checkpoint['epoch']                              #加载开始的epoch

scheduler.last_epoch = start_epoch        #因为设置了学习率的下降方式,这里也要更新一下。如果没设置可以不用管

参考

深度之眼pytorch深度学习框架班

猜你喜欢

转载自blog.csdn.net/weixin_43183872/article/details/108322214
今日推荐