Reading the tf.estimator.Estimator source: the train method (a very detailed walkthrough; comments and discussion welcome)

**Note: this post contains a lot of code, so it is best read with the source open in PyCharm, following along as you go.**

First, the source:

  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Trains a model given training data `input_fn`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:  * A
        `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
        `(features, labels)` with same constraints as below. * A tuple
        `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
        of string feature name to `Tensor` and `labels` is a `Tensor` or a
        dictionary of string label name to `Tensor`. Both `features` and
        `labels` are consumed by `model_fn`. They should satisfy the expectation
        of `model_fn` from inputs.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or train until `input_fn` generates the `tf.errors.OutOfRange`
        error or `StopIteration` exception. `steps` works incrementally. If you
        call two times `train(steps=10)` then training occurs in total 20 steps.
        If `OutOfRange` or `StopIteration` occurs in the middle, training stops
        before 20 steps. If you don't want to have incremental behavior please
        set `max_steps` instead. If set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train model. If `None`,
        train forever or train until `input_fn` generates the
        `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
        `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
        middle, training stops before `max_steps` steps. Two calls to
        `train(steps=100)` means 200 training iterations. On the other hand, two
        calls to `train(max_steps=100)` means that the second call will not do
        any iteration since first call did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps <= 0`.
    """
    _estimator_api_gauge.get_cell('train').set(True)
    if self.config.task_type in (run_config.TaskType.EVALUATOR,
                                 run_config.TaskType.PS):
      raise ValueError(
          'Train has been called wrong configuration. Please use '
          'tf.estimator.train_and_evaluate which calls proper API according '
          'to given configuration. Current configuration: {}.'.format(
              self.config))

    with context.graph_mode():
      if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
      if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
      if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

      if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already saved.')
          return self

      hooks = _check_hooks_type(hooks)
      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

      saving_listeners = _check_listeners_type(saving_listeners)
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self

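Parameters:

  1. input_fn: a function that supplies the training data as minibatches; see Premade Estimators for details. It must construct and return one of the following:
    (1) A tf.data.Dataset object: the Dataset's outputs must be a (features, labels) tuple, subject to the same constraints as below.
    (2) A tuple (features, labels): features is a tf.Tensor or a dict mapping string feature names to Tensors, and likewise for labels. Both features and labels are consumed by model_fn, so they must match what model_fn expects as input.

  2. hooks: a list of tf.train.SessionRunHook subclass instances, used for callbacks inside the training loop. My own understanding: a hook is a class that serves the estimator through its begin, after_create_session, before_run, after_run and end methods, which run before the session is created, after the session is created, before each run() call, after each run() call, and just before the session is closed, respectively. The SessionRunHook source is in the appendix.

  3. steps: the number of steps to train the model for. If None, training continues forever, or until input_fn raises a tf.errors.OutOfRange error or a StopIteration exception. steps works incrementally: if you call train(steps=10) twice, training runs for 20 steps in total. If OutOfRange or StopIteration occurs partway, training stops before reaching 20 steps. If you don't want incremental behavior, set max_steps instead. If steps is set, max_steps must be None.

  4. max_steps: the total number of steps to train the model for. If None, training continues until input_fn raises a tf.errors.OutOfRange error or a StopIteration exception. If max_steps is set, steps must be None. If OutOfRange or StopIteration occurs partway, training stops before max_steps is reached. Calling train(steps=100) twice means 200 training steps in total, whereas calling train(max_steps=100) twice trains only 100 steps, because the first call has already reached the maximum.

  5. saving_listeners: a list of CheckpointSaverListener objects, used for callbacks that run immediately before or after a checkpoint is saved.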
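
To make the callback order from item 2 concrete, here is a minimal sketch in plain Python, with no TensorFlow involved. RecordingHook and run_with_hooks are hypothetical names invented for this illustration; the real driver is the monitored session, which does considerably more than this toy loop.

```python
class RecordingHook:
    """Hypothetical stand-in for a SessionRunHook; records which callbacks fire."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append('begin')

    def after_create_session(self, session, coord):
        self.calls.append('after_create_session')

    def before_run(self, run_context):
        self.calls.append('before_run')

    def after_run(self, run_context, run_values):
        self.calls.append('after_run')

    def end(self, session):
        self.calls.append('end')


def run_with_hooks(hooks, num_steps):
    """Toy driver (an assumption, not TensorFlow code) showing the callback order."""
    for h in hooks:
        h.begin()                       # before the session exists
    session = object()                  # placeholder for a real session
    for h in hooks:
        h.after_create_session(session, None)
    for _ in range(num_steps):
        for h in hooks:
            h.before_run(None)          # before each run() call
        # ... the real session.run(train_op) would happen here ...
        for h in hooks:
            h.after_run(None, None)     # after each run() call
    for h in hooks:
        h.end(session)                  # just before the session is closed


hook = RecordingHook()
run_with_hooks([hook], num_steps=2)
print(hook.calls)
```

The recorded list shows begin and after_create_session firing once, the before_run/after_run pair firing once per step, and end firing last.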

Now the code, statement by statement:

    _estimator_api_gauge.get_cell('train').set(True)
    if self.config.task_type in (run_config.TaskType.EVALUATOR,
                                 run_config.TaskType.PS):
      raise ValueError(
          'Train has been called wrong configuration. Please use '
          'tf.estimator.train_and_evaluate which calls proper API according '
          'to given configuration. Current configuration: {}.'.format(
              self.config))

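The first line appears to record, in an internal gauge, that the train API was used. The check after it matters for distributed training: if the current task type is EVALUATOR or PS, calling train directly raises a ValueError that points you to tf.estimator.train_and_evaluate. I don't use distributed training myself and don't know it well, so I'll skip the details for now.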

    with context.graph_mode():
      if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
      if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
      if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

      if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already saved.')
          return self

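The lines above validate steps and max_steps: they cannot both be set, and whichever one is set must be positive. If max_steps is set and the global step loaded from the checkpoint in model_dir has already reached it, training is skipped and self is returned right away.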
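
The difference between the two arguments can be sketched in plain Python. plan_training is a hypothetical helper written for this post; it reuses the argument checks shown above and takes the current global step as a plain integer instead of loading it from a checkpoint.

```python
def plan_training(global_step, steps=None, max_steps=None):
    """Return the global step at which training would stop, or None for no limit.

    Hypothetical helper: `global_step` stands in for the value that the real
    code loads from the checkpoint directory.
    """
    if steps is not None and max_steps is not None:
        raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))
    if max_steps is not None:
        # Absolute target: if the checkpoint is already there, nothing is trained.
        return max(global_step, max_steps)
    if steps is not None:
        # Incremental target: always `steps` more than the current position.
        return global_step + steps
    return None


# Two calls with steps=100 end at global step 200 ...
after_first = plan_training(0, steps=100)
after_second = plan_training(after_first, steps=100)

# ... while two calls with max_steps=100 end at global step 100.
max_first = plan_training(0, max_steps=100)
max_second = plan_training(max_first, max_steps=100)
print(after_second, max_second)
```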

      hooks = _check_hooks_type(hooks)

This line checks that every hook passed in is a SessionRunHook. Stepping into _check_hooks_type:

def _check_hooks_type(hooks):
  """Returns hooks if all are `SessionRunHook`, raises TypeError otherwise."""
  hooks = list(hooks or [])
  for h in hooks:
    if not isinstance(h, training.SessionRunHook):
      raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))
  return hooks

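The code is straightforward: None becomes an empty list, and any element that is not a SessionRunHook raises a TypeError.

      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

This line converts steps or max_steps into a StopAtStepHook, which is what actually controls when training stops. Stepping into _convert_train_steps_to_hooks: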

  def _convert_train_steps_to_hooks(self, steps, max_steps):
    """Create hooks to run correct number of steps in training.

    Args:
      steps: number of steps to run during training.
      max_steps: maximum number of steps to be run during training. It'll be
        the maximum number of steps the model will train to after restoring
        from checkpoint even across multiple estimator.train calls.

    Returns:
      List of hooks to be passed to the estimator.
    """
    if steps is not None or max_steps is not None:
      if self._train_distribution:
        steps_per_run = getattr(
            self._train_distribution.extended, 'steps_per_run', 1)
        if steps_per_run > 1:
          return [basic_session_run_hooks._MultiStepStopAtStepHook(  # pylint: disable=protected-access
              steps, max_steps, steps_per_run)]
      return [training.StopAtStepHook(steps, max_steps)]
    else:
      return []

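This is as described above, so I won't repeat it: with neither argument set, no hook is added; otherwise a StopAtStepHook is returned (or a _MultiStepStopAtStepHook when a distribution strategy runs more than one step per run).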

      saving_listeners = _check_listeners_type(saving_listeners)

This line checks that every entry in saving_listeners is a CheckpointSaverListener. Stepping into _check_listeners_type:

def _check_listeners_type(saving_listeners):
  """Check listeners type."""
  listeners = list(saving_listeners or [])
  for l in listeners:
    if not isinstance(l, training.CheckpointSaverListener):
      raise TypeError(
          'saving_listeners must be a list of CheckpointSaverListener, '
          'given: {}'.format(l))
  return listeners

CheckpointSaverListener is the interface for listening to a CheckpointSaverHook; it lets you run custom operations right before and right after a checkpoint is saved.
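
As a rough illustration of that before/after pattern, here is a plain-Python sketch. RecordingListener and save_with_listeners are hypothetical names; the before_save and after_save method names follow the listener interface as I understand it, but the save routine itself is only a toy and not the CheckpointSaverHook implementation.

```python
class RecordingListener:
    """Hypothetical listener; method names follow CheckpointSaverListener."""

    def __init__(self):
        self.events = []

    def before_save(self, session, global_step_value):
        self.events.append(('before_save', global_step_value))

    def after_save(self, session, global_step_value):
        self.events.append(('after_save', global_step_value))


def save_with_listeners(listeners, global_step_value, saved_steps):
    """Toy save routine (an assumption, not TensorFlow code)."""
    for listener_obj in listeners:
        listener_obj.before_save(None, global_step_value)
    saved_steps.append(global_step_value)  # stands in for writing the checkpoint
    for listener_obj in listeners:
        listener_obj.after_save(None, global_step_value)


listener = RecordingListener()
saved_steps = []
save_with_listeners([listener], 100, saved_steps)
print(listener.events)
```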

      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self

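These are the last three lines: they call self._train_model to train the model, log the loss it returns for the final step, and return self. Stepping into _train_model: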

  def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)

It branches into distributed and single-machine training. I'll skip the distributed path, self._train_model_distributed, and look at the single-machine path, self._train_model_default:

  def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Initiate training with `input_fn`, without `DistributionStrategies`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      Loss from training
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)

      # Skip creating a read variable if _create_and_assert_global_step
      # returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
      if global_step_tensor is not None:
        training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access

      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, ModeKeys.TRAIN, self.config)
      global_step_tensor = training_util.get_global_step(g)
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)

Again, statement by statement:

    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)

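These lines create a new graph and make it the default, set the device function used for training, set the random seed, and create the global step tensor.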

      if global_step_tensor is not None:
        training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access

This creates the global_step_read tensor and adds it to the GLOBAL_STEP_READ_KEY collection:

    ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)

Next:

      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, ModeKeys.TRAIN))

This returns the features and labels built by input_fn, together with the input_hooks created to initialize them. Stepping into self._get_features_and_labels_from_input_fn:

  def _get_features_and_labels_from_input_fn(self, input_fn, mode):
    """Extracts the `features` and labels from return values of `input_fn`."""
    return estimator_util.parse_input_fn_result(
        self._call_input_fn(input_fn, mode))

It passes the result of self._call_input_fn(input_fn, mode) to estimator_util.parse_input_fn_result. First, the code of self._call_input_fn(input_fn, mode):

  def _call_input_fn(self, input_fn, mode, input_context=None):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: `tf.estimator.ModeKeys`

    Returns:
      The return value of the passed `input_fn`, which should be one of:

        * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
            tuple `(features, labels)` with same constraints as below.
        * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
          dictionary of string feature name to `Tensor` and `labels` is a
          `Tensor` or a dictionary of string label name to `Tensor`. Both
          `features` and `labels` are consumed by `model_fn`. They should
          satisfy the expectation of `model_fn` from inputs.

    Raises:
      ValueError: if `input_fn` takes invalid arguments.
    """
    input_fn_args = function_utils.fn_args(input_fn)
    kwargs = {}
    if 'mode' in input_fn_args:
      kwargs['mode'] = mode
    if 'params' in input_fn_args:
      kwargs['params'] = self.params
    if 'config' in input_fn_args:
      kwargs['config'] = self.config
    if input_context and 'input_context' in input_fn_args:
      logging.info('The `input_fn` accepts an `input_context` which will '
                   'be given by DistributionStrategy')
      kwargs['input_context'] = input_context
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)

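As the code shows, self._call_input_fn(input_fn, mode) first inspects which arguments input_fn declares, passes only those among mode, params, config and input_context that it actually accepts, and then calls input_fn on the CPU device and returns the result. Its return value is therefore the same as that of input_fn: either a tf.data.Dataset object or a (features, labels) tuple. That result is then handed to estimator_util.parse_input_fn_result: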
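
The argument filtering can be sketched without TensorFlow. The helper below is hypothetical and uses inspect.signature from the standard library as a stand-in for function_utils.fn_args; the two input functions are made up for the example.

```python
import inspect


def call_with_supported_kwargs(fn, **available):
    """Pass only the keyword arguments that `fn` declares (hypothetical helper)."""
    declared = inspect.signature(fn).parameters
    kwargs = {k: v for k, v in available.items() if k in declared}
    return fn(**kwargs)


def input_fn_plain():
    # Declares no arguments, so it is called with none.
    return 'no args'


def input_fn_with_params(params):
    # Declares only `params`, so `mode` and `config` are not passed.
    return 'batch_size={}'.format(params['batch_size'])


available = {'mode': 'train', 'params': {'batch_size': 32}, 'config': None}
plain_result = call_with_supported_kwargs(input_fn_plain, **available)
params_result = call_with_supported_kwargs(input_fn_with_params, **available)
print(plain_result, params_result)
```

This is why an input_fn may take no arguments at all, or any subset of mode, params and config, without the caller having to know which.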


def parse_input_fn_result(result):
  """Gets features, labels, and hooks from the result of an Estimator input_fn.

  Args:
    result: output of an input_fn to an estimator, which should be one of:

      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
          tuple (features, labels) with same constraints as below.
      * A tuple (features, labels): Where `features` is a `Tensor` or a
        dictionary of string feature name to `Tensor` and `labels` is a
        `Tensor` or a dictionary of string label name to `Tensor`. Both
        `features` and `labels` are consumed by `model_fn`. They should
        satisfy the expectation of `model_fn` from inputs.

  Returns:
    Tuple of features, labels, and input_hooks, where features are as described
    above, labels are as described above or None, and input_hooks are a list
    of SessionRunHooks to be included when running.

  Raises:
    ValueError: if the result is a list or tuple of length != 2.
  """
  input_hooks = []
  if isinstance(result, dataset_ops.DatasetV2):
    iterator = dataset_ops.make_initializable_iterator(result)
    input_hooks.append(_DatasetInitializerHook(iterator))
    result = iterator.get_next()
  return parse_iterator_result(result) + (input_hooks,)

Reading the code:

  if isinstance(result, dataset_ops.DatasetV2):
    iterator = dataset_ops.make_initializable_iterator(result)
    input_hooks.append(_DatasetInitializerHook(iterator))
    result = iterator.get_next()

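This block does two things. First, it checks whether result is a dataset_ops.DatasetV2; if so, it creates an initializable iterator (a tf.compat.v1.data.Iterator) from it via dataset_ops.make_initializable_iterator, which lets a DatasetV2 be iterated the same way as a DatasetV1. Second, because that iterator starts out uninitialized, it adds a _DatasetInitializerHook to input_hooks to initialize it. result is then replaced by iterator.get_next(). For the difference between DatasetV2 and DatasetV1, see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py, which describes it as: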

    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
    take anything in its constructor whereas in the DatasetV2, we expect
    subclasses to create a variant_tensor and pass it in to the super() call.

Then:

  return parse_iterator_result(result) + (input_hooks,)

Stepping into parse_iterator_result:

def parse_iterator_result(result):
  """Gets features, labels from result."""
  if isinstance(result, (list, tuple)):
    if len(result) != 2:
      raise ValueError(
          'input_fn should return (features, labels) as a len 2 tuple.')
    return result[0], result[1]
  return result, None

This function splits result into features and labels; when there are no labels, it fills in None.
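
A quick check of that behavior in plain Python. The function body is copied from the source above so the example is self-contained; the sample inputs are made up, with plain lists and dicts standing in for tensors.

```python
def parse_iterator_result(result):
    """Gets features, labels from result (copied from the source above)."""
    if isinstance(result, (list, tuple)):
        if len(result) != 2:
            raise ValueError(
                'input_fn should return (features, labels) as a len 2 tuple.')
        return result[0], result[1]
    return result, None


# A (features, labels) tuple is split as-is.
with_labels = parse_iterator_result(({'x': [1.0, 2.0]}, [0, 1]))

# Anything that is not a list or tuple is treated as features only.
features_only = parse_iterator_result({'x': [1.0, 2.0]})

# parse_input_fn_result appends the hooks list, giving a 3-tuple.
input_hooks = []
full_result = parse_iterator_result({'x': [1.0, 2.0]}) + (input_hooks,)
print(with_labels, features_only, full_result)
```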

So:

      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, ModeKeys.TRAIN))

is now fully explained.

Next:

      worker_hooks.extend(input_hooks)

simply adds the hooks just created for initializing the iterator to worker_hooks.

Next:

      estimator_spec = self._call_model_fn(
          features, labels, ModeKeys.TRAIN, self.config)

Stepping into self._call_model_fn:

  def _call_model_fn(self, features, labels, mode, config):
    """Calls model function.

    Args:
      features: features dict.
      labels: labels dict.
      mode: `tf.estimator.ModeKeys`
      config: `tf.estimator.RunConfig`

    Returns:
      An `tf.estimator.EstimatorSpec` object.

    Raises:
      ValueError: if `model_fn` returns invalid objects.
    """
    model_fn_args = function_utils.fn_args(self._model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    else:
      if labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = mode
    if 'params' in model_fn_args:
      kwargs['params'] = self.params
    if 'config' in model_fn_args:
      kwargs['config'] = config

    logging.info('Calling model_fn.')
    model_fn_results = self._model_fn(features=features, **kwargs)
    logging.info('Done calling model_fn.')

    if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
      raise ValueError('model_fn should return an EstimatorSpec.')

    return model_fn_results

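It first inspects the arguments of model_fn, then calls model_fn and checks that the result is an EstimatorSpec, returning it if so and raising a ValueError otherwise. It also raises a ValueError if input_fn produced labels but model_fn does not take a labels argument. So self._call_model_fn builds the model defined by the user's model_fn and returns its EstimatorSpec.

Moving on:

      global_step_tensor = training_util.get_global_step(g)

This fetches the global step tensor from graph g, that is, the tensor that counts training steps.

Next:

      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)

Following it further, into self._train_with_estimator_spec: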

  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec."""
    if self._warm_start_settings:
      logging.info('Warm-starting with WarmStartSettings: %s' %
                   (self._warm_start_settings,))
      warm_starting_util.warm_start(*self._warm_start_settings)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.append(
        training.NanTensorHook(estimator_spec.loss)
    )
    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps)
      )
    worker_hooks.extend(estimator_spec.training_hooks)

    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    if (self._config.cluster_spec and type(
        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                               'CollectiveAllReduceStrategyV1',
                                               'MultiWorkerMirroredStrategy')):
      return self._train_with_estimator_spec_distributed(
          estimator_spec, worker_hooks, saving_listeners)

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps

    # Check existence of appropriate cluster spec fields, as well as master and
    # worker nodes. As master also performs evaluation, summary writing must
    # occur on a different node. The presence of a worker is also checked to
    # prevent reassigning hooks for single-replica jobs with just a master node.
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              training.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              training.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))

    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        log_step_count_steps=log_step_count_steps) as mon_sess:
      loss = None
      any_step_done = False
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        any_step_done = True
    if not any_step_done:
      logging.warning('Training with estimator made no steps. '
                      'Perhaps input is empty or misspecified.')
    return loss

This one is long, so let's take it slowly.

First:

    if self._warm_start_settings:
      logging.info('Warm-starting with WarmStartSettings: %s' %
                   (self._warm_start_settings,))
      warm_starting_util.warm_start(*self._warm_start_settings)

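This warm-starts the model: if you supply warm-start settings, such as the path of a pre-trained model, the model is initialized from it and training continues from those earlier results.

Next:

    # Check if the user created a loss summary, and add one if they didn't.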
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

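The English comments explain these lines well: if no summary named 'loss' exists, one is added, and the loss is added to the LOSSES collection.

Next:

    worker_hooks.extend(hooks)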
    worker_hooks.append(
        training.NanTensorHook(estimator_spec.loss)
    )

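This adds the user-supplied hooks to worker_hooks, plus a NanTensorHook that monitors the loss. With its default settings, that hook raises NanLossDuringTrainingError when the loss becomes NaN, which ends training.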
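
The NaN check can be pictured with a small plain-Python sketch. NanLossError and check_loss are hypothetical names for this post, and the fail_on_nan_loss flag mirrors the hook option of the same name as I understand it; the real hook runs inside after_run and reads the loss tensor's value.

```python
import math


class NanLossError(RuntimeError):
    """Hypothetical stand-in for NanLossDuringTrainingError."""


def check_loss(loss, fail_on_nan_loss=True):
    """Return True if training should stop because the loss is NaN."""
    if math.isnan(loss):
        if fail_on_nan_loss:
            raise NanLossError('NaN loss during training.')
        return True   # ask for a graceful stop instead of raising
    return False


stop_on_finite = check_loss(0.25)
stop_on_nan = check_loss(float('nan'), fail_on_nan_loss=False)
print(stop_on_finite, stop_on_nan)
```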

Next:

    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps)
      )
    worker_hooks.extend(estimator_spec.training_hooks)

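These lines add a LoggingTensorHook that logs the loss and the step every log_step_count_steps iterations, along with the training_hooks from the estimator_spec.

Continuing: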

    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

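These lines check whether a Saver for saving the model already exists, either on the scaffold or in the ops.GraphKeys.SAVERS collection; if not, one is created and added to that collection.

Continuing: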

    if (self._config.cluster_spec and type(
        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                               'CollectiveAllReduceStrategyV1',
                                               'MultiWorkerMirroredStrategy')):
      return self._train_with_estimator_spec_distributed(
          estimator_spec, worker_hooks, saving_listeners)

This is for distributed training; skipping it for now.

Next:

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

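These lines sort out checkpoint saving. They first collect any CheckpointSaverHook already present among the worker hooks and the spec's training_chief_hooks. If checkpoint saving is configured (save_checkpoints_secs or save_checkpoints_steps) and no such hook exists, a default CheckpointSaverHook is created as a chief hook. If saving_listeners were passed in, they are attached to the first saver hook; with no saver hook available, a ValueError is raised.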
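
The same decision logic, reduced to plain Python. FakeSaverHook and attach_listeners are hypothetical names, and the boolean saving_configured stands in for the two RunConfig fields; only the branching is meant to match the source above.

```python
class FakeSaverHook:
    """Hypothetical stand-in for CheckpointSaverHook; only keeps listeners."""

    def __init__(self):
        self._listeners = []


def attach_listeners(all_hooks, saving_listeners, saving_configured):
    """Return (chief_hooks, saver_hooks), following the branching shown above."""
    chief_hooks = []
    saver_hooks = [h for h in all_hooks if isinstance(h, FakeSaverHook)]
    if saving_configured and not saver_hooks:
        # No saver hook supplied by the user: create the default one.
        chief_hooks = [FakeSaverHook()]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
        if not saver_hooks:
            raise ValueError(
                'There should be a CheckpointSaverHook to use saving_listeners.')
        # Only the first saver hook receives the listeners.
        saver_hooks[0]._listeners.extend(saving_listeners)
    return chief_hooks, saver_hooks


# Case 1: saving is configured and the user supplied no saver hook.
chief_default, savers_default = attach_listeners([], ['listener_a'], True)

# Case 2: the user already supplied a saver hook, so no default is created.
user_hook = FakeSaverHook()
chief_user, savers_user = attach_listeners([user_hook], ['listener_b'], True)
print(len(chief_default), len(chief_user), user_hook._listeners)
```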

Next:

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps

    # Check existence of appropriate cluster spec fields, as well as master and
    # worker nodes. As master also performs evaluation, summary writing must
    # occur on a different node. The presence of a worker is also checked to
    # prevent reassigning hooks for single-replica jobs with just a master node.
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              training.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              training.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))

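These lines handle a cluster that has both a master and workers: the default summary and step-count hooks are switched off (save_summary_steps = 0, log_step_count_steps = None), and worker 0 instead gets an explicit SummarySaverHook and a StepCounterHook that counts steps.

The last part: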

    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        log_step_count_steps=log_step_count_steps) as mon_sess:
      loss = None
      any_step_done = False
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        any_step_done = True
    if not any_step_done:
      logging.warning('Training with estimator made no steps. '
                      'Perhaps input is empty or misspecified.')
    return loss

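Here training runs inside a MonitoredTrainingSession, looping until the session reports that it should stop, and the last loss is returned. Both train_op and loss come from the EstimatorSpec returned by your model_fn; in TRAIN mode the EstimatorSpec requires both, so the loss is not generated automatically.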
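
The shape of that loop, including the "no steps" case, can be sketched in plain Python. FakeMonitoredSession is a hypothetical stand-in that simply replays a list of canned loss values; it is not how MonitoredTrainingSession works internally.

```python
class FakeMonitoredSession:
    """Hypothetical stand-in for MonitoredTrainingSession; replays canned losses."""

    def __init__(self, losses):
        self._losses = iter(losses)
        self._next = next(self._losses, None)

    def should_stop(self):
        # Stop once the canned losses are exhausted (like running out of input).
        return self._next is None

    def run(self):
        loss = self._next
        self._next = next(self._losses, None)
        return None, loss   # mirrors run([train_op, loss])


def train_loop(sess):
    """Same control flow as the loop above; returns (loss, any_step_done)."""
    loss = None
    any_step_done = False
    while not sess.should_stop():
        _, loss = sess.run()
        any_step_done = True
    return loss, any_step_done


normal_run = train_loop(FakeMonitoredSession([0.9, 0.5, 0.3]))
empty_run = train_loop(FakeMonitoredSession([]))
print(normal_run, empty_run)
```

With empty input the loop body never runs, so the returned loss stays None; that is the situation in which the real code logs its "made no steps" warning.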

That completes the walkthrough of the train method of tf.estimator.Estimator. If you have questions, feel free to comment.

Appendix

  1. SessionRunHook source:
class SessionRunHook(object):
  """Hook to extend calls to MonitoredSession.run()."""

  def begin(self):
    """Called once before using the session.

    When called, the default graph is the one that will be launched in the
    session.  The hook can modify the graph by adding new operations to it.
    After the `begin()` call the graph will be finalized and the other callbacks
    can not modify the graph anymore. Second call of `begin()` on the same
    graph, should not change the graph.
    """
    pass

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Called when new TensorFlow session is created.

    This is called to signal the hooks that a new session has been created. This
    has two essential differences with the situation in which `begin` is called:

    * When this is called, the graph is finalized and ops can no longer be added
        to the graph.
    * This method will also be called as a result of recovering a wrapped
        session, not only at the beginning of the overall session.

    Args:
      session: A TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """
    pass

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().

    You can return from this call a `SessionRunArgs` object indicating ops or
    tensors to add to the upcoming `run()` call.  These ops/tensors will be run
    together with the ops/tensors originally passed to the original run() call.
    The run args you return can also contain feeds to be added to the run()
    call.

    The `run_context` argument is a `SessionRunContext` that provides
    information about the upcoming `run()` call: the originally requested
    op/tensors, the TensorFlow Session.

    At this point graph is finalized and you can not add ops.

    Args:
      run_context: A `SessionRunContext` object.

    Returns:
      None or a `SessionRunArgs` object.
    """
    return None

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.

    The `run_context` argument is the same one send to `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.

    If `session.run()` raises any exceptions then `after_run()` is not called.

    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    pass

  def end(self, session):  # pylint: disable=unused-argument
    """Called at the end of session.

    The `session` argument can be used in case the hook wants to run final ops,
    such as saving a last checkpoint.

    If `session.run()` raises exception other than OutOfRangeError or
    StopIteration then `end()` is not called.
    Note the difference between `end()` and `after_run()` behavior when
    `session.run()` raises OutOfRangeError or StopIteration. In that case
    `end()` is called but `after_run()` is not called.

    Args:
      session: A TensorFlow Session that will be soon closed.
    """
    pass

Reposted from blog.csdn.net/voidfaceless/article/details/103023910