Tensorflow高级API介绍:Estimator

1、Estimator架构

2、Estimator使用步骤

3、从源码理解Estimator

(1)Estimator的源码如下:

(2)构建model_fn

(3)什么是 tf.estimator.EstimatorSpec

(4)指定config


1、Estimator架构

            

    可以看到Estimator是属于High level的API,而Mid-level API分别是:

  • Layers:用来构建网络结构
  • Datasets: 用来构建数据读取pipeline
  • Metrics:用来评估网络性能

    可以看到如果使用Estimator,我们只需要关注这三个部分即可,而不用再关心一些太细节的东西,另外也不用再使用烦人的Session了。

2、Estimator使用步骤

         

  • 创建一个或多个输入函数,即input_fn;
  • 定义模型的特征列,即feature_columns;
  • 实例化 Estimator,指定特征列和各种超参数;
  • 在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源。(train, evaluate, predict)

 3、从源码理解Estimator

(1)Estimator的源码如下:

class Estimator(object):
  def __init__(self, 
               model_fn, 
               model_dir=None, 
               config=None, 
               params=None, 
               warm_start_from=None):
  ...

- model_dir: 指定checkpoints和其他日志存放的路径。
- model_fn: 这个是需要我们自定义的网络模型函数,后面详细介绍
- config: 用于控制内部和checkpoints等,如果model_fn函数也定义config这个变量,则会将config传给model_fn
- params: 该参数的值会传递给model_fn。
- warm_start_from: 指定checkpoint路径,会导入该checkpoint开始训练

(2)构建model_fn

def model_fn(
    features,    # This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).
    labels,     # This is batch_labels from input_fn
    mode,       # An instance of tf.estimator.ModeKeys
    params,     # Additional configuration
    config=None
   ):
 ...

- features和labels两个参数是从输入函数中返回的特征和标签批次,也就是说,features 和 labels 是模型将使用的数据;
- params 是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params['n_classes']来定义最终输出节点的个数等。
- config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
- mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN/EVAL/PREDICT 来定义。
    另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。
    例如当你调用estimator.train(...)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN。

     model_fn需要对于不同的模式提供不同的处理方式,并且都需要返回一个 tf.estimator.EstimatorSpec 的实例 。通俗解释就是:模型有训练,验证和测试三种阶段,而且对于不同模式,对数据有不同的处理方式。例如在训练阶段,我们需要将数据喂给模型,模型基于输入数据给出预测值,然后我们在通过预测值和真实值计算出loss,最后用loss更新网络参数,而在评估阶段,我们则不需要反向传播更新网络参数,换句话说,mdoel_fn需要对三种模式设置三套代码。另外model_fn需要返回什么东西呢?Estimator规定model_fn需要返回 tf.estimator.EstimatorSpec,这样它才好更具一般化的进行处理。

(3)什么是 tf.estimator.EstimatorSpec

    它是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的。其源代码如下:

class EstimatorSpec():
  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None,
              prediction_hooks=None):
...
- mode:一个ModeKeys,指定是training(训练)、evaluation(计算)还是prediction(预测).
- predictions:Predictions Tensor or dict of Tensor.
- loss:Training loss Tensor. Must be either scalar, or with shape [1].
- train_op:适用于训练的步骤.
- eval_metric_ops: Dict of metric results keyed by name. The values of the dict can be one of the following:
  (1) instance of Metric class.
  (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. 
      metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). 
      For example, it should not trigger the update_op or requires any input fetching.

    不同模式需要传入不同参数:     

  • 对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
  • 对于mode == ModeKeys.EVAL:必填字段是loss.
  • 对于mode == ModeKeys.PREDICT:必填字段是predictions.

  1)最简单的情况: predict,只需要传入mode和predictions    

predicted_classes = tf.argmax(logits, 1)
if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class_ids': predicted_classes[:, tf.newaxis],
        'probabilities': tf.nn.softmax(logits),
        'logits': logits,
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  2)评估模式:eval 需要传入mode, loss, eval_metric_ops

    如果调用 Estimator 的 evaluate 方法,则 model_fn 会收到 mode = ModeKeys.EVAL。在这种情况下,模型函数必须返回一个包含模型损失和一个或多个指标(可选)的 tf.estimator.EstimatorSpec。

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
# Compute evaluation metrics.
accuracy = tf.metrics.accuracy(
              labels=labels,
              predictions=predicted_classes,
              name='acc_op')
metrics = {'accuracy': accuracy}
if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
              mode, 
              loss=loss, 
              eval_metric_ops=metrics)

  3)训练模式:train 需要传入mode,loss,train_op

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

  4)通用模式: model_fn可以填充独立于模式的所有参数.在这种情况下,Estimator将忽略某些参数.在eval和infer模式中,train_op将被忽略.例子如下:

def model_fn(mode, features, labels):
  predictions = ...
  loss = ...
  train_op = ...
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

(4)指定config

    此处的config需要传入tf.estimator.RunConfig,其源代码如下:

class RunConfig(object):
  """This class specifies the configurations for an `Estimator` run."""

  def __init__(self,
               model_dir=None,
               tf_random_seed=None,
               save_summary_steps=100,
               save_checkpoints_steps=_USE_DEFAULT,
               save_checkpoints_secs=_USE_DEFAULT,
               session_config=None,
               keep_checkpoint_max=5,
               keep_checkpoint_every_n_hours=10000,
               log_step_count_steps=100,
               train_distribute=None,
               device_fn=None,
               protocol=None,
               eval_distribute=None,
               experimental_distribute=None,
               experimental_max_worker_delay_secs=None,
               session_creation_timeout_secs=7200):
...
- model_dir: 指定存储模型参数,graph等的路径;
- save_summary_steps: 每隔多少step就存一次Summaries;
- save_checkpoints_steps: 每隔多少个step就存一次checkpoint;
- save_checkpoints_secs: 每隔多少秒就存一次checkpoint,不可以和save_checkpoints_steps同时指定;
    如果二者都不指定,则使用默认值,即每600秒存一次。如果二者都设置为None,则不存checkpoints。
- keep_checkpoint_max:指定最多保留多少个checkpoints,也就是说当超出指定数量后会将旧的checkpoint删除。
    当设置为None或0时,则保留所有checkpoints;
- keep_checkpoint_every_n_hours:保存checkpoint文件的频率;
- log_step_count_steps:该参数的作用是,(相对于总的step数而言)指定每隔多少step就记录一次训练过程中loss的值,
    同时也会记录global steps/s,通过这个也可以得到模型训练的速度快慢;

猜你喜欢

转载自blog.csdn.net/MOU_IT/article/details/103822448