pytorch模型保存方式:
torchvision.models中的下载下来的pretrained模型,比如:alexnet-owt-4df8aa71.pth,是将模型的参数,按照有序字典的方式保存。类型为:collections.OrderedDict
方式一:仅保存和加载模型参数(推荐的方式)
1 . 保存模型参数:
import torch
torch.save(model.state_dict(), 'save_path_name.pth') # 后缀不重要(torchvision.models下载下来的模型参数后缀为.pth)
2 . 加载模型参数:
import torch
import torch.nn as nn
model.load_state_dict(torch.load('save_path_name.pth'), strict=True) # strict=True表示键值要严格匹配
例子:
import torch
import torch.nn as nn
import torchvision
import AlexNet_Train_Val
model = AlexNet_Train_Val.ModifiedAlexNet(2)
model.load_state_dict(torch.load('./models/AlexNet1.pth'), strict=True) # AlexNet1.pth为之前保存的模型参数文件
方式二:保存和加载整个模型(模型结构和模型参数)
1 . 保存模型:
import torch
torch.save(model, 'save_path_name.pth')
2 . 加载模型:
import torch
import torch.nn as nn
model = torch.load('save_path_name.pth')