保存和加载pytorch模型

当保存和加载模型时,需要熟悉三个核心功能:

torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法。在Pytorch中也有同样功能的方法提供。

>>>torch.save(model,'model.pkl') #保存整个模型
>>>model = torch.load('model.pkl') #加载整个模型
>>>torch.save(alexnet.state_dict(),'params.pkl') #保存网格中的参数
>>>alexnet.load_state_dict(torch.load('params.pkl')) #加载网格中的参数

在torchvision.models模块里,PyTorch提供了一些常用的模型:

常用模型
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3

可以使用torch.util.model_zoo来预加载它们,具体设置通过参数pretrained=True来实现。

>>>import torchvision.models as models
>>>ResNet18 = models.ResNet18(pretrained=True)
>>>alexnet = models.alexnet(pretrained=True)
>>>squeezenet = models.squeezenet1_0(pretrained=True)
>>>vgg16 = models.vgg16(pretrained=True)
>>>densenet = models.densenet161(pretrained=True)
>>>inception = models.inception_v3(pretrained=True)

加载这类预训练模型的过程中,还可以进行微处理。

>>>pretrained_dict = model_zoo.load_url(model_urls['resnet134'])
>>>model_dict = model.state_dict()
>>>pretrained_dict = {
    
    k:v for k,v in pretrained_dict.items()if k in model_dict}
   #将pretrained_dict里不属于model_dict的键剔除掉
>>>model_dict.update(pretrained_dict)  #更新现在有的model_dict
>>>model.load_state_dict(model_dict) 

参考
《PyTorch机器学习从入门到实战》

猜你喜欢

转载自blog.csdn.net/weixin_45656790/article/details/108890049