pytorch模型保存

pytorch模型保存方式:

pytorch官网手册

torchvision.models中的下载下来的pretrained模型,比如:alexnet-owt-4df8aa71.pth,是将模型的参数,按照有序字典的方式保存。类型为:collections.OrderedDict

方式一:仅保存和加载模型参数(推荐的方式)

1 . 保存模型参数:

import torch
torch.save(model.state_dict(), 'save_path_name.pth') # 后缀不重要(torchvision.models下载下来的模型参数后缀为.pth)

2 . 加载模型参数:

import torch
import torch.nn as nn
model.load_state_dict(torch.load('save_path_name.pth'), strict=True) # strict=True表示键值要严格匹配

例子:

import torch
import torch.nn as nn
import torchvision
import AlexNet_Train_Val

model = AlexNet_Train_Val.ModifiedAlexNet(2)
model.load_state_dict(torch.load('./models/AlexNet1.pth'), strict=True) # AlexNet1.pth为之前保存的模型参数文件

方式二:保存和加载整个模型(模型结构和模型参数)

1 . 保存模型:

import torch
torch.save(model, 'save_path_name.pth')

2 . 加载模型:

import torch
import torch.nn as nn
model = torch.load('save_path_name.pth')

猜你喜欢

转载自blog.csdn.net/tsq292978891/article/details/79434892