Pytorch在加载的模型基础上继续训练

深度学习网络模型的训练往往会花费挺长时间,这时候万一断电了,机器死机了,那真的气不打一处来,想砸机器的冲动都来了有没有?

不过也不用太着急,一般咱们的模型都写有模型参数保存功能,比如这样:

if epoch%10 == 1:
	torch.save(model.state_dict(),'{}/moilenetV2_{}_{}.pth.format('./models',epoch,acc))

我们只需要找到这个模型保存的位置,然后把最新的这个模型参数加载到我们的model中,就可以接着这个参数进行训练了。要加载的代码一般放在model定义之后(就是确定model的结构了),模型进行训练之前。要加载代码如下:

Resume = True
# Resume = False
if Resume:
	path_checkpoint = 'your/new/model/path.pth'
	checkpoint = torch.load(path_checkpoint, map_location = torch.device('cpu'))
	model.load_state_dict(checkpoint)

变量Resume可以作为开关,如果想在训练好的模型基础上进行finetune(微调)的话,就把它设置为True,从零训练的话就设置为False。当然咱们这种出问题,接着训练的就设置为True就行。

知识扩充

训练模型的保存包括两种:
1、保存整个模型框架以及模型参数(存储文件过大,不推荐)

torch.save(model,path)

2、仅仅保存模型的参数文件(推荐)

torch.save(model.state_dict(),path)

"state_dict"表示state dictionary,即字典类型的参数,模型本身的参数。

其中torch.load()函数可以加载模型参数,为了保证GPU显存够用,推荐令map_location = torch.device(‘cpu’)

假如你就想加载到gpu中,可以令map_location = torch.device(‘cuda’)

最后用model.load_state_dict(checkpoint)把参数加载完成。

好了,快去训练你的模型吧!有问题欢迎留言~

猜你喜欢

转载自blog.csdn.net/WYKB_Mr_Q/article/details/118546342