Tensorflow API 讲解——tf.estimator.Estimator.evaluate

evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)

作用

使用验证集 input_fn 对 model 进行验证。
对于每一步,执行 input_fn(返回数据集的一个 batch)。

  • 已经进行了 steps 个 batch,或者
  • input_fn 抛出了出界异常(OutOfRangeErrorStopIteration

参数

input_fn:此函数构造出验证所需的输入数据,需要返回以下结构之一:

  • 一个 tf.data.Dataset 对象:Dataset 对象的输出必须是一个元组 (features, labels),和下面的规格相同。
  • 一个元组 (features, labels):features 是一个 Tensor 或者字典(a dictionary of string feature name to Tensor )。labels 是一个 Tensor 或者字典(a dictionary of string label name to Tensor)。featureslabels 都被 model_fn 所使用(model_fntf.estimator.Estimator 的构造函数的参数之一)。他们应该满足 model_fn 输入端的需求。

steps:验证模型的步数。如果是 None,则一直验证下去,直至input_fn 抛出了出界异常。

hooksSessionRunHook子类实例的 list。作为验证的回调函数。

checkpoint_path:特定检查点的路径。如果是 None,则默认为 model_dir 中最近的检查点(model_dirtf.estimator.Estimator 的构造函数的参数之一)

name:验证的名字。使用者可以针对不同的数据集运行多个验证操作,比如训练集 vs 测试集。不同验证的结果被保存在不同的文件夹中,且分别出现在 tensorboard 中。

返回值

返回一个字典,包括 model_fn 中指定的评价指标、global_step

异常抛出

ValueError:如果 step 小于等于0
ValueError:如果 model_dir 指定的模型没有被训练,或者指定的 checkpoint_path 为空。

示例

先定义Estimator

cnn_model = tf.estimator.Estimator(
    model_fn=model_function, model_dir=save_model_path
)

然后进行训练:

cnn_model.train(
    input_fn=lambda: get_train_batch(train_file_path), steps=steps_per_eval)

最后进行验证:

evaluate_results = cnn_model.evaluate(
    input_fn=lambda: get_val_batch(val_file_path),
    steps=eval_steps_per_train_cycle)

其中,数据是从 tfrecords 中读取的:

def get_train_batch(data_dir, batch_size=conf.batch_size, set_name='train', use_distortion=True):
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)


def get_val_batch(data_dir, batch_size=conf.batch_size, set_name='val', use_distortion=False):
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)

class DataSet(object):
    ....

    def get_batch(self, file_path, batch_size):
        """

        :param batch_size: train, val, test batch_size is different
        :param file_path:
        :return:
        """
        files = tf.data.Dataset.list_files(file_path)
        dataset = files.apply(
            tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=conf.num_parallel_readers,
                                                sloppy=True))
        if self.set_name == 'train':
            dataset = dataset.repeat(conf.train_epochs)
            dataset = dataset.shuffle(conf.shuffle_buffer_size)

        dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=self.parser_single_img, batch_size=batch_size,
                                                              num_parallel_batches=conf.num_parallel_batches))
        dataset = dataset.prefetch(conf.batch_size)
        iterator = dataset.make_one_shot_iterator()
        img_batch, label_batch = iterator.get_next()
        return img_batch, label_batch

猜你喜欢

转载自blog.csdn.net/HappyRocking/article/details/80229508