TensorFlow API explained: tf.estimator.Estimator.evaluate

evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)

Purpose

Evaluates the model on the validation data provided by input_fn.
For each step, it calls input_fn, which returns one batch of the dataset. Evaluation runs until either:

  • steps batches have been processed, or
  • input_fn raises an end-of-input exception (OutOfRangeError or StopIteration).
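The stopping rule above can be modelled in plain Python. This is a minimal sketch, not TensorFlow's implementation: OutOfRangeError here is a local stand-in for tf.errors.OutOfRangeError, and run_evaluation and make_source are hypothetical helpers that only count how many batches such a loop would consume.

```python
class OutOfRangeError(Exception):
    """Local stand-in for tf.errors.OutOfRangeError (hypothetical, sketch only)."""


def run_evaluation(next_batch, steps=None):
    """Return how many batches an evaluate()-style loop would consume.

    next_batch: zero-argument callable returning the next batch; it raises
        OutOfRangeError or StopIteration once the input is exhausted.
    steps: maximum number of batches, or None to run until the input ends.
    """
    if steps is not None and steps <= 0:
        raise ValueError("steps must be > 0, got %r" % (steps,))
    consumed = 0
    while steps is None or consumed < steps:
        try:
            next_batch()  # a real loop would run the metric update ops here
        except (OutOfRangeError, StopIteration):
            break  # end of input: stop early, this is not treated as an error
        consumed += 1
    return consumed


def make_source(num_batches):
    """Hypothetical input that yields num_batches batches, then runs out."""
    it = iter(range(num_batches))
    return lambda: next(it)  # raises StopIteration when exhausted
```

For example, run_evaluation(make_source(10), steps=4) consumes 4 batches, while run_evaluation(make_source(3), steps=5) stops after 3 because the input runs out first. This is why steps=None is a convenient way to evaluate exactly one pass over a validation set that is not repeated.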

Parameters

input_fn: A function that constructs the input data for evaluation. It must return one of the following:

  • A tf.data.Dataset object: the outputs of the Dataset must be a tuple (features, labels) with the same constraints as below.
  • A tuple (features, labels): features is a Tensor or a dictionary mapping string feature names to Tensors; labels is a Tensor or a dictionary mapping string label names to Tensors. Both features and labels are consumed by model_fn (an argument of the tf.estimator.Estimator constructor), so they must match what model_fn expects as input.
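To make the (features, labels) contract concrete, here is a plain-Python sketch that checks whether a value has that shape. It is an illustration under loud assumptions: plain Python lists stand in for Tensors, and is_valid_input_fn_output and toy_input_fn are hypothetical names, not part of the TensorFlow API. For the tf.data.Dataset form, each element of the dataset would need to pass the same check.

```python
def is_tensor_like(x):
    # Assumption for this sketch: a plain list plays the role of a Tensor.
    return isinstance(x, list)


def is_valid_part(part):
    """A 'Tensor', or a dict mapping string names to 'Tensors'."""
    if isinstance(part, dict):
        return all(isinstance(k, str) and is_tensor_like(v) for k, v in part.items())
    return is_tensor_like(part)


def is_valid_input_fn_output(output):
    """Check the (features, labels) tuple form described above."""
    return (isinstance(output, tuple) and len(output) == 2
            and is_valid_part(output[0]) and is_valid_part(output[1]))


def toy_input_fn():
    # Hypothetical batch of two examples: dict of features, plain labels.
    features = {"pixels": [0.1, 0.7], "height": [28, 28]}
    labels = [3, 5]
    return features, labels
```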

steps: Number of steps for which to evaluate the model. If None, evaluation runs until input_fn raises an end-of-input exception.

hooks: A list of SessionRunHook subclass instances, used as callbacks during the evaluation call.
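The order in which hooks are invoked can be sketched as follows. This is a simplified model, not TensorFlow code: the real tf.train.SessionRunHook methods take extra arguments (before_run(run_context), after_run(run_context, run_values), end(session)) and there is also an after_create_session(session, coord) callback, all omitted here. CountingHook and run_steps_with_hooks are hypothetical names.

```python
class CountingHook:
    """Records calls using SessionRunHook's method names (arguments omitted)."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def before_run(self):
        self.calls.append("before_run")

    def after_run(self):
        self.calls.append("after_run")

    def end(self):
        self.calls.append("end")


def run_steps_with_hooks(hooks, num_steps):
    """Simplified driver showing the hook call order around each step."""
    for h in hooks:
        h.begin()
    for _ in range(num_steps):
        for h in hooks:
            h.before_run()
        # session.run(eval_ops) would execute here in a real evaluation
        for h in hooks:
            h.after_run()
    for h in hooks:
        h.end()
```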

checkpoint_path: Path to a specific checkpoint to evaluate. If None, the latest checkpoint in model_dir is used (model_dir is an argument of the tf.estimator.Estimator constructor).
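The default of picking the most recent checkpoint can be approximated as below. In TensorFlow 1.x the real helper is tf.train.latest_checkpoint(model_dir), which reads the `checkpoint` state file in that directory; this sketch instead assumes checkpoints follow the usual model.ckpt-<step>.index naming and picks the highest step. latest_checkpoint_by_step is a hypothetical name.

```python
import re

# Matches index files such as "model.ckpt-1200.index" (naming is an assumption)
_INDEX_RE = re.compile(r"^(.*-(\d+))\.index$")


def latest_checkpoint_by_step(filenames):
    """Return the checkpoint prefix with the highest global step, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        m = _INDEX_RE.match(name)
        if m and int(m.group(2)) > best_step:
            best_step, best_prefix = int(m.group(2)), m.group(1)
    return best_prefix
```

Given the file names in a model_dir (for example from os.listdir(model_dir)) containing both step 500 and step 1200 checkpoints, this returns 'model.ckpt-1200'. Passing os.path.join(model_dir, that prefix) as checkpoint_path would pin the evaluation to that checkpoint explicitly.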

name: Name of the evaluation, for when you run several evaluations on different datasets, such as training data vs. test data. The metrics of different evaluations are saved in separate folders and appear separately in TensorBoard.
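In TensorFlow 1.x the Estimator appears to write evaluation summaries to model_dir/eval when name is None and to model_dir/eval_<name> otherwise. The sketch below reproduces that naming rule; eval_dir is a hypothetical helper, so check the behaviour against the TensorFlow version you actually use.

```python
import os


def eval_dir(model_dir, name=None):
    """Folder for evaluation summaries (sketch of the TF 1.x naming rule)."""
    return os.path.join(model_dir, "eval" if not name else "eval_" + name)
```

So calling evaluate(..., name='train') and evaluate(..., name='test') would produce two sibling folders, eval_train and eval_test, which TensorBoard lists as separate runs.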

Return value

A dictionary containing the evaluation metrics specified in model_fn, keyed by name, plus a global_step entry holding the global step at which this evaluation was performed.
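Because the result is a plain dictionary, it is easy to log. A minimal sketch, assuming every metric other than global_step is a float; format_eval_results is a hypothetical helper, and the numbers in the usage line below are made up.

```python
def format_eval_results(results):
    """Render an evaluate() result dict as one log line, global_step first."""
    step = results["global_step"]
    metrics = ", ".join(
        "%s=%.4f" % (key, value)
        for key, value in sorted(results.items())
        if key != "global_step"
    )
    return "step %d: %s" % (step, metrics)
```

For example, format_eval_results({'accuracy': 0.91, 'loss': 0.2734, 'global_step': 1200}) returns 'step 1200: accuracy=0.9100, loss=0.2734'.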

Exceptions

ValueError: if steps is less than or equal to 0.
ValueError: if no trained model exists in model_dir, or if the given checkpoint_path is empty.

Example

First, define the Estimator:

cnn_model = tf.estimator.Estimator(
    model_fn=model_function, model_dir=save_model_path
)

Then, train it:

cnn_model.train(
    input_fn=lambda: get_train_batch(train_file_path), steps=steps_per_eval)

Finally, evaluate it:

evaluate_results = cnn_model.evaluate(
    input_fn=lambda: get_val_batch(val_file_path),
    steps=eval_steps_per_train_cycle)
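In practice the train and evaluate calls above are usually alternated for several cycles. The sketch below assumes only that the estimator object has train(input_fn=..., steps=...) and evaluate(input_fn=..., steps=...) methods, as tf.estimator.Estimator does; train_and_evaluate_cycles is a hypothetical helper, and FakeEstimator is a stand-in used only so the loop can run without TensorFlow.

```python
def train_and_evaluate_cycles(estimator, train_input_fn, eval_input_fn,
                              num_cycles, steps_per_eval, eval_steps):
    """Alternate train() and evaluate(); return the list of metric dicts."""
    history = []
    for _ in range(num_cycles):
        estimator.train(input_fn=train_input_fn, steps=steps_per_eval)
        history.append(estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps))
    return history


class FakeEstimator:
    """Hypothetical stand-in with the same call shape, for testing only."""

    def __init__(self):
        self.global_step = 0

    def train(self, input_fn, steps=None):
        input_fn()  # the real Estimator builds its input pipeline from input_fn
        self.global_step += steps  # training resumes from the last step
        return self

    def evaluate(self, input_fn, steps=None):
        input_fn()
        return {"global_step": self.global_step}
```

Because Estimator.train resumes from the latest checkpoint in model_dir, each cycle adds steps_per_eval more steps; with the fake estimator, three cycles of 100 steps report global steps 100, 200 and 300. TensorFlow 1.4+ also provides tf.estimator.train_and_evaluate, driven by TrainSpec and EvalSpec, as a built-in alternative.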

Here the input functions read the data from TFRecord files:

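import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x

import conf  # the author's configuration module (not shown in the original)


def get_train_batch(data_dir, batch_size=conf.batch_size, set_name='train', use_distortion=True):
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)


def get_val_batch(data_dir, batch_size=conf.batch_size, set_name='val', use_distortion=False):
    # Validation data: no augmentation, no repeat/shuffle (see get_batch)
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)


class DataSet(object):
    ...  # __init__ (which sets self.set_name) and parser_single_img are omitted in the original

    def get_batch(self, file_path, batch_size):
        """Build the input pipeline and return one batch of (images, labels).

        :param file_path: path or glob pattern of the TFRecord files
        :param batch_size: batch size (may differ between train, val and test)
        :return: tuple (img_batch, label_batch) of tensors
        """
        # Read several TFRecord files in parallel; sloppy=True allows out-of-order output for speed
        files = tf.data.Dataset.list_files(file_path)
        dataset = files.apply(
            tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=conf.num_parallel_readers,
                                                sloppy=True))
        # Only the training set is repeated and shuffled; the validation set is read once,
        # so evaluate() stops with OutOfRangeError when it is exhausted
        if self.set_name == 'train':
            dataset = dataset.repeat(conf.train_epochs)
            dataset = dataset.shuffle(conf.shuffle_buffer_size)

        # Parse each record and group the results into batches in one fused op
        dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=self.parser_single_img, batch_size=batch_size,
                                                              num_parallel_batches=conf.num_parallel_batches))
        # prefetch comes after batching, so its buffer size is counted in batches
        dataset = dataset.prefetch(conf.batch_size)
        iterator = dataset.make_one_shot_iterator()
        img_batch, label_batch = iterator.get_next()
        return img_batch, label_batch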
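The training branch of get_batch applies repeat, then shuffle, then batching. The sketch below models those three stages in plain Python to show their effect on element order and batch sizes. It is an approximation of tf.data's behaviour, not its implementation: it leaves out parallel_interleave and record parsing, and shuffle_with_buffer, batch and train_pipeline are hypothetical names.

```python
import random


def shuffle_with_buffer(elements, buffer_size, seed=None):
    """Yield elements as a fixed-size shuffle buffer would (like Dataset.shuffle)."""
    rng = random.Random(seed)
    buffer = []
    for x in elements:
        buffer.append(x)
        if len(buffer) == buffer_size:
            # Buffer full: emit a random element, freeing one slot
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    while buffer:  # input exhausted: drain the remaining elements randomly
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]
        yield buffer.pop()


def batch(elements, batch_size):
    """Group elements into lists of batch_size; the last batch may be smaller."""
    out = []
    for x in elements:
        out.append(x)
        if len(out) == batch_size:
            yield out
            out = []
    if out:
        yield out


def train_pipeline(records, epochs, buffer_size, batch_size, seed=0):
    """repeat(epochs) -> shuffle(buffer_size) -> batch(batch_size), as in the 'train' branch."""
    repeated = (r for _ in range(epochs) for r in records)
    return list(batch(shuffle_with_buffer(repeated, buffer_size, seed), batch_size))
```

With buffer_size=1 nothing is shuffled, and because shuffle comes after repeat, elements from neighboring epochs can mix inside the buffer. In this sketch the final batch may be smaller than batch_size: 10 records repeated for 2 epochs in batches of 3 give six full batches and one batch of 2.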
