PyTorch保存和加载模型(全面汇总)

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数张量建立映射关系.(如model的每一层的weights及偏置等等)

只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等。按理说BN是没有参数可保存的,然而实际上在resnet中是有保存的,因为pytorch的nn.BatchNorm2d默认affine =True,是存在映射的参数的。

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等),但是似乎很少会加载Optimizer。

参考
https://blog.csdn.net/vino_cherish/article/details/84110401
https://www.jianshu.com/p/60fc57e19615
https://www.cnblogs.com/leebxo/p/10920134.html

保存/加载模型参数

#----保存----
torch.save(model.state_dict(), 'params_name.pth') #保存的文件名后缀一般是.pt或.pth
#----加载----
model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth'))  #加载模型参数

这种方法只保存/加载模型参数,不保存/加载模型结构,是官方推荐的方法。

就是要先自己定义模型结构,然后用上面的加载模型参数。

保存/加载模型

#----保存----
torch.save(model, 'model_name.pth')
#----加载----
model = torch.load('model_name.pth')

这种方法会同时保存和加载模型的参数和结构信息。加载时不需要自己定义模型结构,直接从预训练模型中得到模型结构和参数。

只加载重合的部分参数

如果自己的模型跟预训练模型只有部分层是相同的,那么可以只加载这部分相同的参数,只要设置strict参数为False来忽略那些没有匹配到的keys即可。

#----保存----
torch.save(modelA.state_dict(), PATH)
#----加载----
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

自己选择要保存的参数

上面两种方法都是很笼统的加载模型结构或者参数,如果我们想要保存一些其他的变量,可以这样做:

#----保存----
torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
  	'loss': loss,
    'best_prec1': best_prec1,}, 
    'checkpoint_name.tar' )

#----加载----
checkpoint = torch.load('checkpoint_name.tar')

#按关键字获取保存的参数
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
state_dict=checkpoint['state_dict']

model=Model()#定义模型结构
model.load_state_dict(state_dict)

保存多个模型进一个文件中

既然checkpoint中可以按关键字保存不同的参数,那也就可以用关键字标识不同模型的参数来保存,然后同样按关键字加载不同模型的参数。

#----保存----
torch.save({
  'modelA_state_dict': modelA.state_dict(),
  'modelB_state_dict': modelB.state_dict(),
  'optimizerA_state_dict': optimizerA.state_dict(),
  'optimizerB_state_dict': optimizerB.state_dict(),
  ...
  }, PATH)

#----加载----
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelAClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict']
modelB.load_state_dict(checkpoint['modelB_state_dict']
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']

modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()

在这里,保存完模型后加载的时候有时会遇到CUDA out of memory的问题,我google到的解决方法是加上map_location=‘cpu’

checkpoint = torch.load(PATH,map_location='cpu')

查看模型中某些层的参数

假设模型的网络结构如下:

# 定义一个网络
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
# 打印网络的结构
print(model)
 
OUT:
Sequential (
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(relu1): ReLU ()
(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(relu2): ReLU ()
)

查看获取conv1的weight和bias:

model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth'))  #加载模型参数

params=model.state_dict()
for k,v in params.items():
    print(k) #打印网络中的变量名
print(params['conv1.weight']) #打印conv1的weight
print(params['conv1.bias']) #打印conv1的bias

加载部分预训练模型

有的时候想用预训练模型的参数,但是自己的模型定义的层名字不同,无法直接匹配加载,此时可以遍历state_dict,最后加载。(下面的例子是只加载重合参数的意思,其实用strict就行了,但主要是参考他如何遍历state_dict以及最后更新模型参数)

resnet152 = models.resnet152(pretrained=True) #加载模型结构和参数
pretrained_dict = resnet152.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
   也可以直接从官方model_zoo下载:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

或者写详细一点:

model_dict = model.state_dict()
state_dict = {}
for k, v in pretrained_dict.items():
    if k in model_dict.keys():
        # state_dict.setdefault(k, v)
        state_dict[k] = v
    else:
        print("Missing key(s) in state_dict :{}".format(k))
model_dict.update(state_dict)
model.load_state_dict(model_dict)

加载部分预训练模型基本都是这么做的,这个是最详细的:Pytorch模型迁移和迁移学习,导入部分模型参数,他还包括了修改预训练模型参数名的方法(其实也是很简单的):

def string_rename(old_string, new_string, start, end):
    new_string = old_string[:start] + new_string + old_string[end:]
    return new_string
 
 
def modify_model(pretrained_file, model, old_prefix, new_prefix):
    '''
    :param pretrained_file:
    :param model:
    :param old_prefix:
    :param new_prefix:
    :return:
    '''
    pretrained_dict = torch.load(pretrained_file)
    model_dict = model.state_dict()
    state_dict = modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)
    model.load_state_dict(state_dict)
    return model
 
def modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix):
    '''
    修改model dict
    :param pretrained_dict:
    :param model_dict:
    :param old_prefix:
    :param new_prefix:
    :return:
    '''
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            # state_dict.setdefault(k, v)
            state_dict[k] = v
        else:
            for o, n in zip(old_prefix, new_prefix):
                prefix = k[:len(o)]
                if prefix == o:
                    kk = string_rename(old_string=k, new_string=n, start=0, end=len(o))
                    print("rename layer modules:{}-->{}".format(k, kk))
                    state_dict[kk] = v
    return state_dict

此外还有一个具体的例子:Pytorch自由载入部分模型参数并冻结

构建一个字典,使得字典的keys和我们自己创建的网络相同,我们再从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,目前只能想到这个方法应对较为复杂的网络变换。

保存torch.nn.DataParallel模型

使用单GPU训练和使用DataParallel模块进行多GPU训练所保存的模型有所不同,主要是DataParallel保存模型时多了一个module关键字,所以当使用DataParallel来加载一个不使用DataParallel训练出来的模型时,就会报错。
有一个技巧是:

  1. 如果预训练模型是用DataParallel训练的,那么我们就先做model=DataParallel(model)然后再加载预训练模型
  2. 如果预训练模型没有用DataParallel训练,那么我们就先加载预训练模型,再做model=DataParallel(model)

torchvision.models预训练模型

torchvision.models这个包中包含alexnet、densenet、inception、resnet、 squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

import torchvision 
model = torchvision.models.resnet50(pretrained=True)#获取网络结构和预训练模型

pretrained=True 会加载预训练模型,他的加载源码为:

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo   # model_zoo是和导入预训练模型相关的包
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
# model_urls这个字典是预训练模型的下载地址
model_urls = {    
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

若当前模型的网络结构层与model_zoo提供的预训练模型的网络结构层严格对等:

def resnet50(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 

网络结构层不对等(事实上跟上面完全相同,只是多写了几个步骤而已):

def resnet50_cbam(pretrained=False, **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50']) 
        #load_url函数根据model_urls字典下载或导入相应的预训练模型
        now_state_dict = model.state_dict()   # 返回当前model模块的字典
        now_state_dict.update(pretrained_state_dict)
        model.load_state_dict(now_state_dict) #最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
    return model
发布了131 篇原创文章 · 获赞 12 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/weixin_41519463/article/details/103205665