Detailed explanation of saving and loading models in PyTorch (2)

If you review the past and learn the new, you can become a teacher!

1. Reference materials

SAVING AND LOADING MODELS
pytorch model saving, loading and continued training
Detailed explanation of saving and loading models in PyTorch (1)

2. Model saving and loading

1. Build a network model

import torch
import torch.nn as nn
import torch.nn.functional as F


# Model definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
# Instantiate the model
model = Net()

2. Method 1 (recommended)

2.1 Save model

# Save the model
torch.save(model.state_dict(), './model/model_state_dict.pth')

The second argument, './model/model_state_dict.pth', is the path where the model is saved. By convention, PyTorch models are saved with a .pth or .pt suffix, although other suffixes also work.
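To see what is actually written to disk, the sketch below prints the entries of a state_dict and then saves it. It uses a small stand-in model (TinyNet, a hypothetical name, not the Net defined above) and a temporary directory instead of ./model/; both are assumptions made to keep the example short. Note that torch.save() does not create missing folders, so the target folder has to exist first.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in model, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()

# state_dict() maps each parameter name to its tensor
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# torch.save() does not create missing folders, so create the target first
save_dir = os.path.join(tempfile.mkdtemp(), "model")
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "model_state_dict.pth")
torch.save(state, save_path)
```

Only the parameter tensors and their names are stored, which is why the model class must be instantiated again before loading.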

2.2 Load model

Use load_state_dict() to load the model. First call torch.load() to deserialize the saved model parameters; the result is a dictionary, which is then passed to load_state_dict().

model_test1 = Net()   # instantiate the model before loading its parameters

# Load the model
model_test1.load_state_dict(torch.load('./model/model_state_dict.pth'))
model_test1.eval()    # set evaluation mode for inference

load_state_dict() takes a dictionary, so the path './model/model_state_dict.pth' cannot be passed to it directly. torch.load() must first deserialize the saved model parameters.

Serialization means writing data from memory to disk, so saving a model with torch.save() is serialization. Deserialization means reading data from disk back into memory, so loading a model is deserialization.
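The round trip can be checked in memory. This is a minimal sketch, assuming a single nn.Linear layer as the model and an io.BytesIO buffer in place of a file on disk (torch.save() and torch.load() accept file-like objects as well as paths):

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
source = nn.Linear(4, 2)

# Serialization: write the parameter dictionary into a byte buffer
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)

# Deserialization: read the bytes back into a dictionary of tensors
buffer.seek(0)
restored_state = torch.load(buffer)

# load_state_dict() takes that dictionary, not a file path
target = nn.Linear(4, 2)
target.load_state_dict(restored_state)

# Compare every tensor of the original and the restored model
same = all(
    torch.equal(p, q)
    for p, q in zip(source.state_dict().values(), target.state_dict().values())
)
print(same)
```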


3. Method 2 (recommended)

If training is interrupted for some reason, a checkpoint makes it easy to resume from where the last run stopped. For this reason, this method of saving and loading models is highly recommended.

3.1 Save model

# Save a checkpoint
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, './model/model_checkpoint.tar'    # .tar is the conventional suffix for checkpoints
            )

3.2 Load model

# Load the checkpoint
model = Net()
optimizer =  torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar')    # deserialize the checkpoint first
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
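The save and load steps above can be wrapped in two small helpers. This is a sketch, not part of the original post: the function names save_checkpoint and load_checkpoint are hypothetical, the model is a single nn.Linear layer, the checkpoint goes into a temporary directory, and the loss is stored as a plain float via loss.item(). map_location is passed so that a checkpoint written on a GPU can also be read on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch, loss):
    """Hypothetical helper: bundle everything needed to resume training."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(path, model, optimizer, device):
    """Hypothetical helper: restore model and optimizer, return (epoch, loss)."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']


device = torch.device("cpu")
model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One training step so that the optimizer has momentum buffers to save
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "model_checkpoint.tar")
save_checkpoint(path, model, optimizer, epoch=3, loss=loss.item())

# Restore into freshly created objects, as a restarted script would
new_model = nn.Linear(4, 2).to(device)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
epoch, last_loss = load_checkpoint(path, new_model, new_optimizer, device)
print(epoch, last_loss)
```

The model and optimizer must be constructed before their state is loaded, because load_state_dict() only fills in existing objects.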

3.3 Code examples

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. Prepare the datasets
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)

# 2. Wrap the datasets in data loaders
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)

# 3. Build the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, input):
        input = self.model1(input)
        return input


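# 4. Create the network model
model = Net()
model.to(device)

# 5. Set the loss function
loss_fun = nn.CrossEntropyLoss()   # cross-entropy loss

# Set the optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), learning_rate)   # SGD: stochastic gradient descent

# 6. Set some training parameters
Max_epoch = 10    # number of training epochs
total_train_step = 0   # total number of training steps
total_test_step = 0    # total number of test steps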

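# 7. Start training
for epoch in range(Max_epoch):
    print("--- Epoch {} training started ---".format(epoch))

    model.train()     # training mode; not strictly required, but needed when the network has BN or dropout layers
    # The training set is large, so the test set (test_dataset_loader) is used as training data here; the idea is the same
    for data in test_dataset_loader:  # iterate over all batches
        imgs, targets = data
        imgs, targets = imgs.to(device), targets.to(device)

        # Backpropagate and update the parameters
        optimizer.zero_grad()  # reset the gradients for each batch
        outputs = model(imgs)  # forward pass to compute predictions
        loss = loss_fun(outputs, targets) # compute the current loss
        loss.backward()  # backward pass to compute gradients
        optimizer.step()  # update all parameters

        total_train_step += 1

        if total_train_step % 50 == 0:
            print("--- Training step {} finished, Loss: {} ---".format(total_train_step, loss.item()))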

    if epoch > 5:
        print("--- Unexpected interruption ---")
        break

    if (epoch+1) % 2 == 0:
        # Save a checkpoint (the ./model folder must already exist)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # .tar is the conventional suffix for checkpoints
        )

A checkpoint is saved every two epochs. When epoch=6, a break simulates an unexpected interruption of the program, which is reported in the terminal output.

The program was interrupted during epoch 6. The most recently saved model at that point is the checkpoint from epoch 5, model_checkpoint_epoch_5.tar.

We also note that the loss at the end of epoch 5 was around 2.0, so when training resumes, the loss should start at around 2.0.
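Rather than hard-coding the file name of the newest checkpoint, a resume script could look it up. This is a minimal sketch, assuming the model_checkpoint_epoch_{N}.tar naming pattern used above; the helper name latest_checkpoint is hypothetical, and the files created here are empty placeholders in a temporary directory:

```python
import os
import re
import tempfile


def latest_checkpoint(folder):
    """Hypothetical helper: return (epoch, path) of the newest checkpoint, or None."""
    pattern = re.compile(r"model_checkpoint_epoch_(\d+)\.tar$")
    found = []
    for name in os.listdir(folder):
        match = pattern.match(name)
        if match:
            found.append((int(match.group(1)), os.path.join(folder, name)))
    # Compare epochs numerically so that epoch 11 sorts after epoch 9
    return max(found) if found else None


# Simulate the files left behind by the interrupted run (epochs 1, 3 and 5)
folder = tempfile.mkdtemp()
for saved_epoch in (1, 3, 5):
    open(os.path.join(folder, f"model_checkpoint_epoch_{saved_epoch}.tar"), "w").close()

result = latest_checkpoint(folder)
print(result[0])  # 5
```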

Now continue training from the result of the last run. The code is as follows:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. Prepare the datasets
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)

# 2. Wrap the datasets in data loaders
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)

# 3. Build the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, input):
        input = self.model1(input)
        return input


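# 4. Create the network model
model = Net()
model.to(device)

# 5. Set the loss function
loss_fun = nn.CrossEntropyLoss()   # cross-entropy loss

# Set the optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), learning_rate)   # SGD: stochastic gradient descent

# 6. Set some training parameters
Max_epoch = 10    # number of training epochs
total_train_step = 0   # total number of training steps
total_test_step = 0    # total number of test steps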

##########################################################################################
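# Load the checkpoint
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine as well
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar', map_location=device)    # deserialize the checkpoint first
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']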
##########################################################################################

# 7. Start training
for epoch in range(start_epoch+1, Max_epoch):
    print("--- Epoch {} training started ---".format(epoch))

    model.train()     # training mode; not strictly required, but needed when the network has BN or dropout layers
    for data in test_dataset_loader:  # iterate over all batches
        imgs, targets = data
        imgs, targets = imgs.to(device), targets.to(device)

        # Backpropagate and update the parameters
        optimizer.zero_grad()  # reset the gradients for each batch
        outputs = model(imgs)  # forward pass to compute predictions
        loss = loss_fun(outputs, targets)  # compute the current loss
        loss.backward()  # backward pass to compute gradients
        optimizer.step()  # update all parameters

        total_train_step += 1

        if total_train_step % 50 == 0:
            print("--- Training step {} finished, Loss: {} ---".format(total_train_step, loss.item()))

    if (epoch+1) % 2 == 0:
        # Save a checkpoint (the ./model folder must already exist)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
        }, './model/model_checkpoint_epoch_{}.tar'.format(epoch)  # .tar is the conventional suffix for checkpoints
        )

Compared with the previous script, this one adds a step that loads the checkpoint: the block marked off by the two rows of # characters. Loading the checkpoint restores the parameters saved during the previous run, so training continues from the breakpoint.

Running this code shows that training starts from epoch 6 with an initial loss of 1.99, which is close to 2.0. This confirms that training resumes correctly after an interruption.
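One reason the checkpoint also stores the optimizer state is that optimizers such as SGD with momentum keep internal buffers between steps. The sketch below is a toy check, not the CIFAR10 experiment above: it assumes a single nn.Linear layer, fixed random data, and an in-memory buffer as the checkpoint, and the helper names are hypothetical. It compares four uninterrupted steps with two steps, a save and reload, and two more steps, once with and once without the optimizer state.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
inputs = torch.randn(16, 4)
targets = torch.randn(16, 2)
initial_state = nn.Linear(4, 2).state_dict()


def make_model_and_optimizer():
    """Hypothetical helper: every run starts from the same initial weights."""
    model = nn.Linear(4, 2)
    model.load_state_dict(initial_state)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train_steps(model, optimizer, steps):
    for _ in range(steps):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()


# Run A: four steps without interruption
model_a, optimizer_a = make_model_and_optimizer()
train_steps(model_a, optimizer_a, 4)

# Run B: two steps, then a checkpoint
model_b, optimizer_b = make_model_and_optimizer()
train_steps(model_b, optimizer_b, 2)
buffer = io.BytesIO()
torch.save({'model_state_dict': model_b.state_dict(),
            'optimizer_state_dict': optimizer_b.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# Run C: fresh objects, restore model and optimizer, finish the last two steps
model_c, optimizer_c = make_model_and_optimizer()
model_c.load_state_dict(checkpoint['model_state_dict'])
optimizer_c.load_state_dict(checkpoint['optimizer_state_dict'])
train_steps(model_c, optimizer_c, 2)

# Run D: restore only the model weights, so the momentum buffers are lost
model_d, optimizer_d = make_model_and_optimizer()
model_d.load_state_dict(checkpoint['model_state_dict'])
train_steps(model_d, optimizer_d, 2)

resumed_matches = torch.allclose(model_a.weight, model_c.weight)
model_only_matches = torch.allclose(model_a.weight, model_d.weight)
print(resumed_matches, model_only_matches)
```

With plain SGD and no momentum, as in the CIFAR10 scripts above, the optimizer has no such buffers, but saving its state keeps the checkpoint valid if the optimizer settings change later.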

4. Method 3

4.1 Save model

# Save the model
torch.save(model, './model/model.pt')    # here the .pt suffix is used for the saved model

4.2 Load model

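# Load the model
# Note: PyTorch 2.6 and later default to weights_only=True, so loading a whole
# pickled model there requires torch.load(..., weights_only=False)
model_test2 = torch.load('./model/model.pt')
model_test2.eval()   # set evaluation mode for inference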

This method is not recommended, because a model saved this way is prone to errors when it is loaded. An example makes this concrete.

In the example project, the file models.py holds the model definition and sits in a folder named models, while save_model.py contains the code that saves the model. The code is as follows:

from models.models import Net
from torch import optim
import torch


# Instantiate the model
model = Net()

# Initialize the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Save/load method 2: save/load
# Save the whole model
torch.save(model, './models/models.pt')

Executing this file generates models.pt. We can then run load_mode.py to load it. The content of load_mode.py is as follows:

from models.models import Net
import torch


# Load method 2
# Load the model
model_test2 = Net()
model_test2 = torch.load('./models/models.pt')
model_test2.eval()   # set evaluation mode for inference
print(model_test2)

At this point the model loads normally. But suppose we rename the models folder to model.

At this time, if we use the following code to load the model, an error will occur:

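from model.models import Net    # import path updated to the renamed folder
import torch


# Load method 2
# Load the model
model_test2 = Net()
model_test2 = torch.load('./model/models.pt')     # the file path has to be updated here
model_test2.eval()   # set evaluation mode for inference
print(model_test2)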

The reason for this error is that a model saved this way records the import path of the module that defines the model class, not the class itself. When loading, torch.load(path) resolves the class through that recorded path and then loads the parameters. Once the path to the model definition file has changed, that lookup fails and an error is raised.
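The same mechanism can be seen with the standard pickle module, which torch.save() uses by default. The sketch below pickles a collections.OrderedDict as a stand-in for a model object (an assumption for illustration; no PyTorch is involved) and shows that the byte stream contains the import path of the class rather than its code. Rewriting that path in the bytes simulates renaming the folder; the module name collectionz is deliberately nonexistent.

```python
import pickle
from collections import OrderedDict

payload = pickle.dumps(OrderedDict(weight=1.0))

# The pickle stores where to import the class from, not the class itself
print(b"collections" in payload, b"OrderedDict" in payload)

# Simulate a renamed package by changing the recorded module name
# (same length, so the rest of the byte stream stays valid)
broken = payload.replace(b"collections", b"collectionz")

try:
    pickle.loads(broken)
    error_name = None
except ModuleNotFoundError as exc:
    # Unpickling tries to import the recorded module and fails
    error_name = type(exc).__name__
    print(error_name, exc)
```

Saving only the state_dict avoids this dependency, because the dictionary holds tensors and parameter names and the model class is supplied by the loading script.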

In fact, saving and loading models this way causes various other problems as well. If you are interested, you can read this blog post. In short, try to avoid this method of saving and loading models in future work.

Origin blog.csdn.net/m0_37605642/article/details/134191729