[Pytorch]保存与加载神经网络模型

一、保存神经网络模型

        Pytorch模型以pth格式进行保存,为了方便操作我们可以设置两种保存模式:保存一定轮次的模型数据和保存验证集损失函数最小的模型数据。同时也可保存损失函数日志以便分析。代码实现如下:

torch.save(model.state_dict(), os.path.join(保存路径, "文件名"))

        损失函数可以保存在txt文件中

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        #可以添加图像绘制代码

        轮次保存和最佳保存可以添加逻辑代码予以控制。

if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
    #保存代码
if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
    #保存代码

二、神经网络模型加载

        加载的模型一般有两类,一种是直接由pytorch保存的.pth文件,另一种是通用文件.onnx。其中onnx是微软推出的通用网络模型标准,可以实现跨平台部署。onnx文件的加载会在以后描述。

#直接加载
self.net.load_state_dict(torch.load('文件路径', map_location=device))
#分层加载--可以检测不匹配的层
model_dict      = model.state_dict()    #加载模型字典用于比对
pretrained_dict = torch.load(model_path, map_location = device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
    if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
        temp_dict[k] = v
        load_key.append(k)       #匹配成功的层
    else:
        no_load_key.append(k)    #匹配失败的层
model_dict.update(temp_dict)
model.load_state_dict(model_dict)

猜你喜欢

转载自blog.csdn.net/weixin_37878740/article/details/128724667