How TensorFlow's Estimator communicates with model_fn

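When writing a custom estimator, it is essential to understand how the Estimator relates to model_fn and to its other arguments. In short, the Estimator takes the arguments it receives and feeds them into model_fn, and model_fn is the key consumer of that data.

Compared with the estimators in scikit-learn and Spark, TensorFlow's Estimator sits at a higher level of abstraction. The model itself, built according to the hyperparameters, is passed in as an argument, so the structure and definition of the Estimator barely change when the model changes. In Spark and scikit-learn, by contrast, each algorithm usually has its own estimator with its own constructor. TensorFlow is more heavily parameterized and gives more freedom, so it manages parameters differently from the other two.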
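The difference in abstraction can be sketched without TensorFlow at all. In the toy example below, `MiniEstimator`, `linear_model_fn`, and `threshold_model_fn` are hypothetical names. The two-argument `model_fn(features, params)` signature is a deliberate simplification of the real one. The point is only that the estimator class stays the same while the model plugged into it changes.

```python
from typing import Any, Callable, Dict, List

# Simplified, hypothetical model_fn signature: (features, params) -> predictions
ModelFn = Callable[[Dict[str, List[float]], Dict[str, Any]], Dict[str, list]]


class MiniEstimator:
    """Toy estimator: stores a model_fn and its hyperparameters and forwards
    them. Its code does not change when a different model is plugged in."""

    def __init__(self, model_fn: ModelFn, params: Dict[str, Any]) -> None:
        self._model_fn = model_fn
        self._params = params

    def predict(self, features: Dict[str, List[float]]) -> Dict[str, list]:
        # params are passed along untouched; only model_fn interprets them
        return self._model_fn(features, self._params)


def linear_model_fn(features, params):
    # y = w * x + b, with w and b supplied as hyperparameters
    return {"y": [params["w"] * x + params["b"] for x in features["x"]]}


def threshold_model_fn(features, params):
    # 1 where x exceeds the threshold hyperparameter, else 0
    return {"y": [1 if x > params["threshold"] else 0 for x in features["x"]]}


linear = MiniEstimator(linear_model_fn, {"w": 2.0, "b": 1.0})
thresh = MiniEstimator(threshold_model_fn, {"threshold": 0.5})
print(linear.predict({"x": [0.0, 1.0]}))  # {'y': [1.0, 3.0]}
print(thresh.predict({"x": [0.2, 0.9]}))  # {'y': [0, 1]}
```

In a per-algorithm library, these two models would more typically be two separate estimator classes, each with its own constructor arguments.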

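All in all, for the Estimator to use the data passed to it, it has to know what that data is. Java enforces this with static types, and Python relies on duck typing. Otherwise, the receiver learns about the data from metadata, from a shared convention, or from an explicit protocol. Estimator and model_fn have no type-level contract between them. They rely on convention, and that convention is spelled out in the passage below, quoted from the TensorFlow source:
Depending on the value of `mode`, different arguments are required. Namely: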

* For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
* For `mode == ModeKeys.EVAL`: required field is `loss`.
* For `mode == ModeKeys.PREDICT`: required field is `predictions`.
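To make the per-mode contract above concrete, here is a minimal pure-Python sketch of the kind of check that enforces it. It does not use TensorFlow. `check_spec`, `REQUIRED_FIELDS`, and the `ModeKeys` enum are hypothetical stand-ins; the enum's string values were chosen to match `tf.estimator.ModeKeys`. A plain dict plays the role of the `EstimatorSpec` that model_fn returns.

```python
from enum import Enum
from typing import Any, Dict


class ModeKeys(str, Enum):
    # Hypothetical stand-in; string values chosen to match tf.estimator.ModeKeys
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "infer"


# Required fields per mode, taken from the contract listed above
REQUIRED_FIELDS = {
    ModeKeys.TRAIN: ("loss", "train_op"),
    ModeKeys.EVAL: ("loss",),
    ModeKeys.PREDICT: ("predictions",),
}


def check_spec(mode: ModeKeys, spec: Dict[str, Any]) -> None:
    """Raise ValueError if a field required for `mode` is missing or None."""
    missing = [name for name in REQUIRED_FIELDS[mode] if spec.get(name) is None]
    if missing:
        raise ValueError(f"Missing {missing} for mode {mode.value!r}")


check_spec(ModeKeys.EVAL, {"loss": 0.3})                   # passes
check_spec(ModeKeys.PREDICT, {"predictions": {"y": [1]}})  # passes
try:
    check_spec(ModeKeys.TRAIN, {"loss": 0.3})
except ValueError as err:
    print(err)  # Missing ['train_op'] for mode 'train'
```

A single model_fn therefore usually branches on `mode` and fills in only the fields that mode needs.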

````python
class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The Estimator object wraps a model which is specified by a model_fn,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or predictions.

  All outputs (checkpoints, event files, etc.) are written to model_dir, or a
  subdirectory thereof. If model_dir is not set, a temporary directory is
  used.

  The config argument can be passed a tf.estimator.RunConfig object containing
  information about the execution environment. It is passed on to the
  model_fn, if the model_fn has a parameter named "config" (and input
  functions in the same manner). If the config parameter is not passed, it is
  instantiated by the Estimator. Not passing config means that defaults useful
  for local execution are used. Estimator makes config available to the model
  (for instance, to allow specialization based on the number of workers
  available), and also uses some of its fields to control internals, especially
  regarding checkpointing.

  The params argument contains hyperparameters. It is passed to the
  model_fn, if the model_fn has a parameter named "params", and to the input
  functions in the same manner. Estimator only passes params along, it does
  not inspect it. The structure of params is therefore entirely up to the
  developer.

  None of Estimator's methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use model_fn to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Calling methods of Estimator will work while eager execution is enabled.
  However, the model_fn and input_fn are not executed eagerly, Estimator
  will switch to graph mode before calling all user-provided functions (incl.
  hooks), so their code has to be compatible with graph mode execution. Note
  that input_fn code using tf.data generally works in both graph and eager
  modes.
  @end_compatibility
  """

  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an Estimator instance.

    See [estimators](https://tensorflow.org/guide/estimators) for more
    information.

    To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    `tf.estimator.WarmStartSettings`.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `tf.Tensor` or `dict` of same (for multi-head models).
                 If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
                 be passed. If the `model_fn`'s signature does not accept
                 `mode`, the `model_fn` must still be able to handle
                 `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
                 prediction. See `tf.estimator.ModeKeys`.
          * `params`: Optional `dict` of hyperparameters.  Will receive what
                 is passed to Estimator in `params` parameter. This allows
                 to configure Estimators from hyper parameter tuning.
          * `config`: Optional `estimator.RunConfig` object. Will receive what
                 is passed to Estimator as its `config` parameter, or a default
                 value. Allows setting up things in your `model_fn` based on
                 configuration such as `num_ps_replicas`, or `model_dir`.

        * Returns:
          `tf.estimator.EstimatorSpec`

      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be the same. If both are `None`, a
        temporary directory will be used.
      config: `estimator.RunConfig` configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                       warm-start from, or a `tf.estimator.WarmStartSettings`
                       object to fully configure warm-starting.  If the string
                       filepath is provided instead of a
                       `tf.estimator.WarmStartSettings`, then all variables are
                       warm-started, and it is assumed that vocabularies
                       and `tf.Tensor` names are unchanged.

    Raises:
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
````

Reposted from www.cnblogs.com/wdmx/p/10010433.html