[tf]tf.supervisor和新的tf.train.MonitoredTrainingSession()

tf.train.Supervisor已经被弃用了官方建议使tf.train.MonitoredTrainingSession()作为代替。

tf.supervisor

Supervisor可以自动的帮我们做一些事情比如:

  • 自动去checkpoint加载数据(如果有checkpoint那么就加载这个checkpoint,如果没有的话那么就从0开始训练)。
  • 自动进行全局变量的初始化。
  • 自身有一个saver,用于保存checkpoint
  • 自身就有一个summary_computed用来保存summary,不需要我们自己写入。
import tensorflow as tf
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a,b)
update = tf.assign(a,c)
tf.scalar_summary("a",a)
init_op = tf.global_variables_initializer() 
merged_summary_op = tf.merge_all_summaries()
sv = tf.train.Supervisor(logdir="/home/keith/tmp/",init_op=init_op) #logdir用来保存checkpoint和summary
saver=sv.saver #创建saver
with sv.managed_session() as sess: #会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
    for i in xrange(1000):
        update_ = sess.run(update)
        print update_
        if i % 10 == 0:
            merged_summary = sess.run(merged_summary_op)
            sv.summary_computed(sess, merged_summary,global_step=i)
        if i%100 == 0:
            saver.save(sess,logdir="/home/keith/tmp/",global_step=i)

tf.train.MonitoredTrainingSession()

tf.train.MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None
)
  • is_chief意思是在分布式集群中是否是主节点。
  • 如果save_summaries_stepssave_summaries_secs都是None的时候,则默认100个step保存一次。
  • config: 一个 tf.ConfigProto类,用来配置session
    一个使用的例子。
  with monitored_session.MonitoredTrainingSession(
      master=master,
      is_chief=is_chief,
      checkpoint_dir=logdir,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config,
      max_wait_secs=max_wait_secs) as session:
    loss = None
    while not session.should_stop():
      loss = session.run(train_op)
  return loss

一个常用的小例子:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir = './Checkpoints',
    hooks = [hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                   tf.train.NanTensorHook(loss)]
    save_checkpoint_steps = 100) as sess:

一个用于训练的完整的小例子:

    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],  # 将上面定义的_LoggerHook传入
            config=tf.ConfigProto(
                log_device_placement=False)) as sess:

        coord = tf.train.Coordinator()
        # 开启文件读取队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        while not sess.should_stop():
            sess.run(train_op)

        coord.request_stop()
        coord.join(threads)

MonitoredTrainingSession继承自MonitoredSession

当MonitoredSession初始化的时候,会按顺序执行下面操作:

  • 调用hookbegin()函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook里面的_step属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。

  • 通过调用scaffold.finalize()初始化计算图

  • 创建会话

  • 通过初始化Scaffold提供的操作(op)来初始化模型

  • 如果checkpoint存在的话,restore模型的参数

  • launches queue runners

    扫描二维码关注公众号,回复: 5233631 查看本文章
  • 调用hook.after_create_session()
    然后,当run()函数运行的时候,按顺序执行下列操作:

  • 调用hook.before_run()

  • 调用TensorFlow的 session.run()

  • 调用hook.after_run()

  • 返回用户需要的session.run()的结果

  • 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话
    最后,当调用close()退出时,按顺序执行下列操作:

  • 调用hook.end()

  • 关闭队列和会话

  • 阻止OutOfRange错误
    需要注意的是:该类不是一个tf.Session() ,因为它不能被设置为默认会话,不能被传递给saver.save,也不能被传递给tf.train.start_queue_runners,这也解释了为什么在开启会话后我们必须手动调用tf.train.start_queue_runners()

各种Hook

  • tf.train.SummarySaverHook:如果summary_writer没有给定,但是output_dir给定了那么就会创建一个writer
__init__(
    save_steps=None,
    save_secs=None,
    output_dir=None,
    summary_writer=None,
    scaffold=None,
    summary_op=None
)
11452592-055eb8e99a4f94d0.png
image.png
  • chekpointSaverHook:
saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir = model_dir,
        save_steps = 100
)

猜你喜欢

转载自blog.csdn.net/weixin_33709364/article/details/87128130
tf
今日推荐