【人工智能概论】 神经网络模型的保存与读取

【人工智能概论】 神经网络模型的保存与读取


一. 前言

  • 搭建并训练好神经网络模型后,可以将模型以及相应参数进行保存,方便后续调用。
  • PyTorch提供了两套读写方式,方式一:模型结构与参数都保存,方式二:只保存参数。
  • 模型的保存时机也是个可以考虑的点。

二. 同时保存模型结构与参数

  • 以torchvision提供的vgg16模型为例。
  • 模型保存:
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 不仅保存网络模型,也保存网络模型中的相关参数
torch.save(vgg16, "vgg16_model1.pth")
  • 模型加载:
import torch
import torchvision

model = torch.load("vgg16_model1.pth")
print(model) # 查看模型结构

三. 只保存模型参数

  • 仍以torchvision提供的vgg16模型为例。
  • 模型保存:
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 只保存模型的参数,占用空间更小,官方推荐
torch.save(vgg16.state_dict(), "vgg16_model2.pth")
  • 模型加载:
import torch
import torchvision

# 相较于第一种方法,该方法在加载时需把模型构建起来。
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_model2.pth"))
print(model)

四. 保存的时机——轮次保存和最佳保存

  • 可以采用轮次保存最佳保存,添加相应的逻辑代码即可实现。

4.1 轮次保存

  • 每间隔若干轮保存一次,能得到多组历史参数信息;
  • 这有个好处,万一过拟合了,还可以用过往的参数。
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
    #保存代码

4.2 最佳保存

  • 不同轮数下的训练效果会有波动,在不发生过拟合的前提下,这样可以获得表现最佳的参数数据。
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
    #保存代码

猜你喜欢

转载自blog.csdn.net/qq_44928822/article/details/130262673