[Introduction to Artificial Intelligence] Saving and reading neural network models

1. Introduction

  • After building and training a neural network model, you can save the model and its parameters for later use.
  • PyTorch provides two ways to save and load models: method 1 saves both the model structure and the parameters, while method 2 saves only the parameters.
  • When to save the model is also worth considering.

2. Saving the model structure and parameters together

  • Take the vgg16 model provided by torchvision as an example.
  • Saving the model:
import torch
import torchvision

# Note: newer torchvision versions deprecate pretrained=False in favor of weights=None
vgg16 = torchvision.models.vgg16(pretrained=False)

# Save not only the network structure but also its parameters
torch.save(vgg16, "vgg16_model1.pth")
  • Loading the model:
import torch
import torchvision

# In PyTorch 2.6+, torch.load defaults to weights_only=True, so loading
# a full pickled model requires passing weights_only=False explicitly
model = torch.load("vgg16_model1.pth", weights_only=False)
print(model)  # inspect the model structure
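  • One caveat of method 1: torch.save(model, ...) pickles the entire model object, so the class definition must be available when the file is loaded. A minimal sketch of the pitfall, using a hypothetical custom class MyNet:
import torch
from torch import nn

class MyNet(nn.Module):
    # A hypothetical custom model, used only to illustrate the pitfall.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

net = MyNet()
torch.save(net, "mynet.pth")

# Loading this file from another script raises an AttributeError
# unless the MyNet class is defined or imported there first.
net2 = torch.load("mynet.pth", weights_only=False)
print(net2)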

3. Saving only the model parameters

  • Again, take the vgg16 model provided by torchvision as an example.
  • Saving the model:
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# Save only the model's parameters; this takes less space and is the officially recommended approach
torch.save(vgg16.state_dict(), "vgg16_model2.pth")
  • Loading the model:
import torch
import torchvision

# Unlike method 1, the model must be constructed first,
# and only then are the saved parameters loaded into it
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_model2.pth"))
print(model)
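  • In practice, the state dict is often bundled with other training state into a single checkpoint file. A minimal sketch; the file name checkpoint.pth, the optimizer settings, and the stored epoch value are assumptions for illustration:
import torch
import torchvision

model = torchvision.models.vgg16(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save the model and optimizer state dicts plus the current epoch in one file.
torch.save({
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "checkpoint.pth")

# To restore, build the model and optimizer first, then load their states.
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]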

4. When to save – periodic saving and best saving

  • You can implement periodic saving and best saving by adding the corresponding logic to the training loop.

4.1 Periodic saving

  • Save once every several epochs, so that multiple historical sets of parameters are kept;
  • The advantage is that even if the model later overfits, you can still fall back to an earlier set of parameters (a fuller sketch follows the snippet below).
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
    # saving code goes here
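  • A minimal training-loop sketch of periodic saving; the values of Epoch and save_period, the omitted training step, and the file-name pattern are all assumptions for illustration:
import torch
import torchvision

model = torchvision.models.vgg16(pretrained=False)
Epoch = 100        # total number of training epochs
save_period = 10   # save every 10 epochs

for epoch in range(Epoch):
    # ... one epoch of training would run here ...
    if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
        # Keep a separate state-dict file for each saved epoch.
        torch.save(model.state_dict(), f"vgg16_epoch{epoch + 1:03d}.pth")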

4.2 Best saving

  • Training performance fluctuates from epoch to epoch. Saving only when the validation loss reaches a new minimum captures the best-performing parameters without overfitting (a fuller sketch follows the snippet below).
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
    # saving code goes here
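  • A minimal sketch of best saving that tracks the lowest validation loss seen so far; the validation step and file name are assumptions, and the original snippet's loss_history helper is replaced by a plain float:
import torch
import torchvision

model = torchvision.models.vgg16(pretrained=False)
best_val_loss = float("inf")

for epoch in range(100):
    val_loss = 0.0  # placeholder; a real loop would compute the epoch's validation loss here
    if val_loss <= best_val_loss:
        best_val_loss = val_loss
        # Overwrite the single "best" checkpoint whenever validation improves.
        torch.save(model.state_dict(), "best_model.pth")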
