先有一个模型:
my_resnet = MyResNet(*args, **kwargs)
两种加载权重方法:
1.基于推荐保存的方式
保存方式:
torch.save(my_resnet.state_dict(), "my_resnet.pth")
对应的加载方式:
my_resnet.load_state_dict(torch.load("my_resnet.pth"))
- 直接torch.load
my_resnet = torch.load("my_resnet.pth")
加载部分预训练模型
PyTorch 中的 torchvision 里已经有很多常用的模型了,可以直接调用:AlexNet 、VGG 、ResNet 、SqueezeNet 、DenseNet
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 将 pretrained_dict 里不属于 model_dict 的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的 model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict)