【pytorch】Resuming training from a checkpoint

every blog every motto: You can do more than you think.

0. Preface

Stop partway through training, then pick up later from where it left off.

1. Main content

1.1 Saving checkpoint information

After each epoch, save the information needed to resume training later: the model, the optimizer, and the epoch number.

import os
import torch

for epoch in range(start_epoch, end_epoch):
    for i, data in enumerate(dataloader):
        pass  # forward / backward / optimizer step for one batch

    # -------------------------------------------------------
    # Save a checkpoint after every epoch so that training
    # can later be resumed from this point
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': self.net.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict()
    }
    torch.save(checkpoint,
               os.path.join(self.save_chpt,
                            'epoch_%d_loss_%.3f.pth' % (epoch, epoch_fuse_loss / ite_num_per_epoch)))
    print('Checkpoint saved; training can be resumed from it later.')
    # -------------------------------------------------------
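As a side note, the original snippet mixed '%'-style formatting with str.format; with the corrected '%'-style pattern the saved filename expands like this (the numbers are purely illustrative):

# Illustrative values only
'epoch_%d_loss_%.3f.pth' % (3, 0.12345)   # -> 'epoch_3_loss_0.123.pth'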

1.2 Resuming training

First instantiate the model and the optimizer, then load the checkpoint as follows:

if self.subsequent_training:  # resume a previous training run from a checkpoint
    checkpoints = torch.load(os.path.join(self.save_chpt, 'xxx.pth'))
    # the saved epoch has already finished, so continue from the next one
    self.start_epoch = checkpoints['epoch'] + 1
    self.optimizer.load_state_dict(checkpoints['optimizer_state_dict'])
    self.net.load_state_dict(checkpoints['model_state_dict'])

    print('Resuming previous training; loaded checkpoint entries:', list(checkpoints.keys()))
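Putting the two halves together, the following is a minimal, self-contained sketch of a resumable training loop. The toy model, optimizer settings, file path, and epoch count are assumptions made for illustration, not part of the original post:

# Minimal runnable sketch; model/optimizer/path/epoch count are illustrative assumptions
import os
import torch
import torch.nn as nn

net = nn.Linear(10, 1)                                   # toy placeholder model
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
save_chpt = './checkpoints'
os.makedirs(save_chpt, exist_ok=True)
ckpt_path = os.path.join(save_chpt, 'last.pth')          # single rolling checkpoint

start_epoch = 0
if os.path.exists(ckpt_path):                            # resume when a checkpoint exists
    ckpt = torch.load(ckpt_path)
    net.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    start_epoch = ckpt['epoch'] + 1                      # continue from the next epoch

for epoch in range(start_epoch, 10):
    # ... one epoch of forward / backward / optimizer steps ...
    torch.save({
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, ckpt_path)

Overwriting a single rolling file keeps the resume logic trivial; with per-epoch filenames as in section 1.1, one would instead pick the newest file in save_chpt before loading.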

Reprinted from blog.csdn.net/weixin_39190382/article/details/120394587