tensorflow 1.6 修改checkpoint的saver机制

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/a5nan/article/details/91958664

前段时间公司上马了一个机器学习的项目,在基础环境搭好以后,默认机制存在几个痛点:

  • 每隔10分钟保持一个checkpoint,
  • 保存下来的step无法追溯其loss值,
  • 最多只给保存5个点,
  • 无法获取最小loss的checkpoint

这些需求其实跟tensorflow本身的业务无关,只是修改一下他训练过程中的保存机制。下面记录一下这些问题的解决办法。

1. 最多只给保存5个点:
这个最先搜到了答案:tf.train.saver()中添加一个参数:max_to_keep=10,默认是5,我的目录是:object_detection/trainer.py。
2.每隔10分钟保持一个checkpoint:
这个是搜索log信息反追的代码,每次保存checkpoint的时候会有一个log:

Saving checkpoint to path...

搜出来的结果在这个文件里:lib/python3.6/site-packages/tensorflow/python/training/supervisor.py。这个init函数里面有个参数save_model_secs他的参数注解里面有说默认是600秒,设为0就取消定时保存。哪里在调他呢?
lib/python3.6/site-packages/tensorflow/contrib/slim/python/slim/learning.py。其中supervisor.Supervisor(… … save_model_secs = save_interval_secs … …)。save_interval_secs 正好是在train()这个启动训练的函数所在的地方(trainer.py里面slim.learning.train()启动训练)。
**结论:**修改train()的save_interval_secs的值来修改定时保存的值,0为取消定时保存。
3.无法获取最小loss的checkpoint:
这个最开始想要用earlystop的方法,但由于对模型训练过程中的收敛过程一无所知,也就不知道如何设置stop相关的参数,于是放弃了。
目前用的方法是每一个step跑完就去检查loss值,把最佳loss值的checkpoint保存下来,同时关掉定时保存,把最大保存数量设为了10。具体实现如下:
也是用了log信息反追代码的方法,每个step跑完就有一个log:

global step xxx: loss = xxx (xxx sec/step)

这个log就是在learning.py里面train_step(),train()里面train_step_fn=train_step,于是找到train()中调用train_step_fn() 的地方,然后把迭代保存最佳loss的代码加进去,

if total_loss < best_loss:
  best_loss = total_loss
  sv.saver.save(sess, sv.save_path, global_step = sv_global_step)

4.保存下来的step无法追溯其loss值:
我的办法是把loss值加到文件名里面,把上面的代码稍加修改:

sv.saver.save(sess, sv.save_path + “-” + str(total_loss), global_step = sv_global_step)

这样就可以在保存目录里面看到保存下来的step对应的loss值。
结果分析:
这样保存下来的结果是最佳loss的step的checkpoint和在其出现之前9个迭代过的历史,注意:他们并不是本次训练的最佳10个结果,他们只是迭代的历史纪录而已。

猜你喜欢

转载自blog.csdn.net/a5nan/article/details/91958664
1.6
今日推荐