Pytorch之模型加载/保存

pytorch保存模型有两种方法:

  1. 保存整个模型 (结构+参数)
  2. 只保存参数(官方推荐)

两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。

保存整个模型

这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。

# 保存
model = Mymodel()
torch.save(model, path)
# 加载 
model = torch.load(path)

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

保存参数

重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:

  • 模型参数
  • 优化器参数
  • loss
  • epoch
  • args

把这些信息用字典包装起来,然后保存即可。

这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:

# 获得保存信息
save_data = {
    
    
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    'epoch': epoch,
    'args': args
     ...
}
# 保存
torch.save(save_data , path)
load_data = torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
# 加载参数
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

猜你喜欢

转载自blog.csdn.net/MoreAction_/article/details/107967053