pytorch调用预训练模型

最近刚开始入手pytorch,搭网络要比tensorflow更容易,有很多预训练好的模型,直接调用即可。
参考链接

import torch
import torchvision.models as models #预训练模型都在这里面
#调用alexnet模型,pretrained=True表示读取网络结构和预训练模型,False表示只加载网络结构,不需要预训练模型
alexnet = models.alexnet(pretrained=False)

print(alexnet)  # 打印模型结构
# 加载预先下载好的预训练参数到alexnet
alexnet.load_state_dict(torch.load('F:/DeepLearning/alexnet-owt-4df8aa71.pth'))
print(alexnet)  # 打印的还是模型结构
pre_dict = alexnet.state_dict()  # 按键值对将模型参数加载到pre_dict
print((k,v) for k ,v in pre_dict.items())  # 打印模型参数
for k ,v in pre_dict.items():  #打印模型每层命名
    print(k)

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
#note:model是自己定义好的模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)

VGG模型同理:

vgg16 = models.vgg16(pretrained=True) #加载网络结构和预训练模型
#static_dict()返回包含模块所有状态的字典
pretrained_dict = vgg16.state_dict()  #返回内置预训练vgg模块的字典
model_dict = model.state_dict()  #返回我们自己model的字典

#------------------------最关键的三步------------------------------------------
# 1. filter out unnecessary keys,也就是说从内置模块中删除掉我们不需要的字典
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 2. overwrite entries in the existing state dict,利用pretrained_dict更新现有的model_dict
model_dict.update(pretrained_dict)

# 3. load the new state dict,更新模型,加载我们真正需要的state_dict
model.load_state_dict(model_dict)

保存和加载模型可以参考链接

猜你喜欢

转载自blog.csdn.net/aaon22357/article/details/82696938