tf.estimator.Estimator

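Overview

tf.estimator.Estimator is a class, so it has to be instantiated before use. Its purpose is to train and evaluate TensorFlow models.
An Estimator object wraps a model specified by a function named model_fn. Given the inputs and a number of other parameters, model_fn returns the ops needed to perform training, evaluation, or prediction. All outputs (checkpoints, event files, etc.) are written to model_dir or a subdirectory of it. If model_dir is not set, a temporary directory is used.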

Initialization

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

'''



Args:
	model_fn: Model function. Follows the signature:
	Args:
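		features: The first item returned from input_fn, either a dict of tensors or a single tensor. It is essentially the model input (where older code fed data through tf.placeholder, here it is returned by input_fn).
		labels: The second item returned from input_fn, either a dict of tensors or a single tensor. Note that when mode is tf.estimator.ModeKeys.PREDICT (that is, during prediction), labels is set to None.
		mode: Optional. Specifies if this is training, evaluation or prediction. See tf.estimator.ModeKeys.
		params: Optional dict of hyperparameters. Receives the params argument given when the Estimator instance was created.
		config: Optional estimator.RunConfig object. Receives the config argument given when the Estimator instance was created, or a default value. Allows setting up things in your model_fn based on configuration such as num_ps_replicas, or model_dir.
		Returns: tf.estimator.EstimatorSpec. Note that model_fn must return an EstimatorSpec instance.


	model_dir: Output directory. Everything the model writes out goes here.
	
	config: A tf.estimator.RunConfig object holding the fixed set of configuration options defined by the library. If these do not cover what you need and you want to add your own settings, use the params argument below.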
	
	params: dict of hyper parameters that will be passed into model_fn. Keys are names of parameters, values are basic python types.
	
	warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a tf.estimator.WarmStartSettings object to fully configure warm-starting. If the string filepath is provided instead of a tf.estimator.WarmStartSettings, then all variables are warm-started, and it is assumed that vocabularies and tf.Tensor names are unchanged.
'''

Key points

The config argument can be passed a tf.estimator.RunConfig object containing information about the execution environment. It is passed on to the model_fn if the model_fn has a parameter named "config" (and to the input functions in the same manner). If the config parameter is not passed, it is instantiated by the Estimator. Not passing config means that defaults useful for local execution are used. Estimator makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing.

The params argument contains hyperparameters. It is passed to the model_fn if the model_fn has a parameter named "params", and to the input functions in the same manner. Estimator only passes params along; it does not inspect it. The structure of params is therefore entirely up to the developer.
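The "passed only if the function has a parameter with that name" rule can be illustrated without TensorFlow. The sketch below is an assumption about the general technique (signature inspection), not TensorFlow's actual code; call_model_fn and both model functions are hypothetical names invented for the illustration.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params, config):
    """Hypothetical helper: forward optional arguments only if model_fn names them."""
    accepted = inspect.signature(model_fn).parameters
    optional = {"mode": mode, "params": params, "config": config}
    kwargs = {name: value for name, value in optional.items() if name in accepted}
    return model_fn(features, labels, **kwargs)


def model_fn_with_params(features, labels, mode, params):
    # params arrives untouched; its structure is entirely up to the developer.
    return {"mode": mode, "learning_rate": params["learning_rate"]}


def model_fn_minimal(features, labels):
    # Declares neither params nor config, so neither is forwarded.
    return {"n_features": len(features)}


hyperparams = {"learning_rate": 0.1}
result_a = call_model_fn(model_fn_with_params, [1.0, 2.0], [0], "train", hyperparams, None)
result_b = call_model_fn(model_fn_minimal, [1.0, 2.0], [0], "train", hyperparams, None)
```

Here result_a sees the hyperparameters because its function declares params, while result_b is called with features and labels only.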

Methods

The train method

Trains the model on data obtained from input_fn.

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

'''
Args:
	input_fn: A function that provides input data for training as minibatches. See Premade Estimators for more information. The function should construct and return one of the following:
		* A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below.
		* A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
	hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop.
	steps: Number of steps for which to train the model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. steps works incrementally. If you call two times train(steps=10) then training occurs in total 20 steps. If OutOfRange or StopIteration occurs in the middle, training stops before 20 steps. If you don't want to have incremental behavior please set max_steps instead. If set, max_steps must be None.
	max_steps: Number of total steps for which to train model. If None, train forever or train until input_fn generates the tf.errors.OutOfRange error or StopIteration exception. If set, steps must be None. If OutOfRange or StopIteration occurs in the middle, training stops before max_steps steps. Two calls to train(steps=100) means 200 training iterations. On the other hand, two calls to train(max_steps=100) means that the second call will not do any iteration since first call did all 100 steps.
	saving_listeners: list of CheckpointSaverListener objects. Used for callbacks that run immediately before or after checkpoint savings.
Returns:
	self, for chaining.

'''

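Main parameters

input_fn: A function that supplies the input data for training, one batch (batch_size examples) at a time. It returns data in the form (features, labels), which is exactly what model_fn takes as input. The return value should be one of the following: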

  1. tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels)
  2. A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor

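max_steps: The maximum total number of training steps (that is, the total number of batches trained on). When training is resumed after a pause, the program checks how many steps have already been run; if that count is greater than or equal to max_steps, no further training happens. For example, two calls to train(max_steps=100) mean that the second call does not run any iteration, since the first call already did all 100 steps.

steps: Trains incrementally on top of whatever has already been done. For example, if you call train(input_fn, steps=10) twice, the model has been trained for 20 iterations in total.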
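The difference between the incremental steps and the absolute max_steps can be made concrete with a small step counter. This is a minimal sketch of the stopping rule as described above, not TensorFlow's implementation; steps_to_run and simulate are hypothetical names, and the sketch ignores the case where the input runs out early.

```python
def steps_to_run(global_step, steps=None, max_steps=None):
    """Hypothetical helper: how many more steps one train() call would run.

    Returns None when neither limit is set (train until the input is exhausted).
    """
    if steps is not None and max_steps is not None:
        raise ValueError("Set only one of steps and max_steps.")
    if steps is not None:
        return steps  # incremental: always this many additional steps
    if max_steps is not None:
        return max(0, max_steps - global_step)  # absolute cap on the total
    return None


def simulate(calls):
    """Apply a sequence of train() keyword arguments to a step counter.

    Assumes every call sets steps or max_steps.
    """
    global_step = 0
    for kwargs in calls:
        global_step += steps_to_run(global_step, **kwargs)
    return global_step


incremental_total = simulate([{"steps": 10}, {"steps": 10}])
capped_total = simulate([{"max_steps": 100}, {"max_steps": 100}])
```

With steps the two calls add up to 20 steps, while with max_steps the second call contributes nothing and the total stays at 100.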

The evaluate method

Evaluates the model given evaluation data input_fn.
For each step, calls input_fn, which returns one batch of data that is fed to the model. Evaluates until either steps batches are processed or input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration).

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

'''
Args:
		input_fn: A function that constructs the input data for evaluation. See Premade Estimators for more information. The function should construct and return one of the following:
			* A tf.data.Dataset object: Outputs of Dataset object must be a tuple (features, labels) with same constraints as below.
			* A tuple (features, labels): Where features is a tf.Tensor or a dictionary of string feature name to Tensor and labels is a Tensor or a dictionary of string label name to Tensor. Both features and labels are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
		steps: Number of steps for which to evaluate model. If None, evaluates until input_fn raises an end-of-input exception.
		hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the evaluation call.
		checkpoint_path: Path of a specific checkpoint to evaluate. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, evaluation is run with newly initialized Variables instead of ones restored from checkpoint.
		name: Name of the evaluation if user needs to run multiple evaluations on different data sets, such as on training data vs test data. Metrics for different evaluations are saved in separate folders, and appear separately in tensorboard.

Returns:
		A dict containing the evaluation metrics specified in model_fn keyed by name, as well as an entry global_step which contains the value of the global step for which this evaluation was performed. For canned estimators, the dict contains the loss (mean loss per mini-batch) and the average_loss (mean loss per sample). Canned classifiers also return the accuracy. Canned regressors also return the label/mean and the prediction/mean.
'''

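Parameter notes

The parameters are similar to those of the train method, so they are not repeated here. The focus is on what this method returns.
It returns a dict containing the metrics specified in advance in model_fn, together with a global_step entry giving the global step at which the evaluation was performed.
For example, model_fn typically contains a definition like the following: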

    estim_specs=tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_classes,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops={"accuracy":acc_op})

Because of the eval_metric_ops={"accuracy": acc_op} entry, the call ends up returning something like this:

{'accuracy': 0.9192, 'loss': 0.28470048, 'global_step': 1000}
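The stopping rule of evaluate (stop after steps batches, or when the input is exhausted) and the shape of the returned dict can be sketched in plain Python. This is a toy stand-in under stated assumptions, not the TensorFlow loop: evaluate_sketch and count_matches are hypothetical names, batches are plain lists, and the "model" simply predicts that each label equals its feature.

```python
def count_matches(features, labels):
    """Toy metric: (number of correct predictions, number of examples)."""
    hits = sum(1 for f, l in zip(features, labels) if f == l)
    return hits, len(labels)


def evaluate_sketch(batches, metric_fn, global_step, steps=None):
    """Hypothetical evaluation loop over an iterable of (features, labels) batches."""
    iterator = iter(batches)
    correct = total = processed = 0
    while steps is None or processed < steps:
        try:
            features, labels = next(iterator)
        except StopIteration:  # end of input
            break
        hits, count = metric_fn(features, labels)
        correct += hits
        total += count
        processed += 1
    accuracy = correct / total if total else 0.0
    # Metrics are keyed by name, plus the global step of the evaluated model.
    return {"accuracy": accuracy, "global_step": global_step}


batches = [([1, 2], [1, 2]), ([3, 4], [3, 0]), ([5, 6], [0, 0])]
full_result = evaluate_sketch(batches, count_matches, global_step=1000)
partial_result = evaluate_sketch(batches, count_matches, global_step=1000, steps=2)
```

With steps=None all three batches are consumed; with steps=2 the loop stops before the third batch, so the two results report different accuracies for the same global_step.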

The predict method

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None,
    yield_single_examples=True
)

'''
Args:
	input_fn: A function that constructs the features. Prediction continues until input_fn raises an end-of-input exception (tf.errors.OutOfRangeError or StopIteration). See Premade Estimators for more information. The function should construct and return one of the following:
		* A tf.data.Dataset object: Outputs of Dataset object must have same constraints as below.
		* features: A tf.Tensor or a dictionary of string feature name to Tensor. features are consumed by model_fn. They should satisfy the expectation of model_fn from inputs.
		* A tuple, in which case the first item is extracted as features.
	
	predict_keys: list of str, name of the keys to predict. It is used if the tf.estimator.EstimatorSpec.predictions is a dict. If predict_keys is used then rest of the predictions will be filtered from the dictionary. If None, returns all.
	
	hooks: List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the prediction call.
	
	checkpoint_path: Path of a specific checkpoint to predict. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, prediction is run with newly initialized Variables instead of ones restored from checkpoint.
	
	yield_single_examples: If False, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
'''

Notes

Given the inputs, predict returns whatever model_fn specifies as its predictions, for example tf.estimator.EstimatorSpec(mode, predictions=pred_classes):

    # ...

    pred_classes = tf.argmax(logits, axis=1)
    pred_probas = tf.nn.softmax(logits)

    # In PREDICT mode, return only the predictions
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)

    # ...


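The parameters are largely the same as those of the train method, so only the following are covered here:
predict_keys: A list of str. If predict_keys is given, the model returns only the entries of predictions whose keys appear in predict_keys. This applies when predictions is a dict.
checkpoint_path: Path of the specific checkpoint to predict from. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, prediction runs with newly initialized variables instead of ones restored from a checkpoint.
yield_single_examples: If False, yields the whole batch as returned by model_fn instead of decomposing the batch into individual elements. This is useful when model_fn returns some tensors whose first dimension is not equal to the batch size.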
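The effect of predict_keys and yield_single_examples can be sketched on plain Python data. This is a hypothetical stand-in for the post-processing only, not the real predict: predict_sketch is an invented name, and each batch is assumed to be a dict mapping a key to a list with one entry per example.

```python
def predict_sketch(batched_predictions, predict_keys=None, yield_single_examples=True):
    """Hypothetical generator mimicking how predictions are filtered and split."""
    for batch in batched_predictions:
        if predict_keys is not None:
            # Keep only the requested keys; everything else is filtered out.
            batch = {key: value for key, value in batch.items() if key in predict_keys}
        if not yield_single_examples:
            yield batch  # the whole batch, exactly as produced
            continue
        # Decompose the batch along its first dimension into single examples.
        batch_size = len(next(iter(batch.values())))
        for i in range(batch_size):
            yield {key: value[i] for key, value in batch.items()}


batches = [{"classes": [1, 0], "probabilities": [[0.1, 0.9], [0.8, 0.2]]}]
singles = list(predict_sketch(batches, predict_keys=["classes"]))
whole = list(predict_sketch(batches, yield_single_examples=False))
```

Like this sketch, the real predict is consumed by iterating over its result, so wrapping the call in list() or a for loop is what actually drives the predictions.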

Reposted from blog.csdn.net/qq_32806793/article/details/85019394