Pytorch:模型的保存与加载

模型保存与加载常用有两种方法,第一种是保存整个模型,包括模型的结构和参数;第二种是保存模型的参数。推荐使用第二种,因为模型一旦很大,第一种加载耗时长,其次第二种加载方式更加灵活,可以加载其他模型的预训练参数,从而使用迁移学习的方法减小训练时长。

一、保存/加载整个模型

  1. 保存模型:
	torch.save(net, 'model_net1.pkl')
  1. 加载模型
	net_parm =  'model_net1.pkl'
    net = torch.load(net_parm)

二、保存/加载模型参数

  1. 保存参数:
  	 torch.save({
          'epoch': nums_epoch,
          'state_dict': net.state_dict(),
      }, 'model_net.pkl')
  1. 加载参数:
	  cuda_gpu = torch.cuda.is_available()
      if(cuda_gpu):
            net = torch.nn.DataParallel(net, device_ids=gpus).cuda()
      if os.path.isfile(net_parm):
            print("=>loading model '{}'".format(net_parm))
            checkpoint = torch.load(net_parm)
            epoch = checkpoint['epoch']
            print(epoch)
            net.load_state_dict(checkpoint['state_dict'])
            print("=>load model success, start epoch: '{}'".format(epoch))

torch.load 返回的是一个 OrderedDict。OrderedDict 是 collections 提供的一种数据结构, 它提供了有序的dict结构。可以将加载的模型打印出来:

	checkpoint = torch.load(net_parm)
	print (checkpoint)

在这里插入图片描述
因此我们知道参数加载的原理就是将相同的Key进行赋值操作,load一般是依据key来加载的,一旦有key不匹配则出错。

3. 部分加载参数
这时候就出现一个问题,如果我们修改了网络结构,但是相同的部分想加载预训练模型的参数,应该如何加载呢?

net.load_state_dict(pretrained_net, strict=False)

设置strict=False,则直接忽略不匹配的key,对于匹配的key则进行正常的赋值。
或者自己设置一个过滤器,过滤不需要的网络层。参考apaszke推荐的做法,即删除与当前model不匹配的key。

pretrained_dict = ...
model_dict = model.state_dict()

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(pretrained_dict)
发布了29 篇原创文章 · 获赞 120 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/qq_21578849/article/details/86573043