evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
Effect
Evaluates the model using the data provided by input_fn.
For each step, it calls input_fn, which returns one batch of the dataset. Evaluation stops once either:
- steps batches have been run, or
- input_fn throws an end-of-input exception (OutOfRangeError or StopIteration).
Parameters
input_fn: A function that constructs the input data needed for evaluation. It must return one of the following structures:
- A tf.data.Dataset object: the outputs of the Dataset must be a (features, labels) tuple with the same constraints as below; a minimal sketch follows this list.
- A (features, labels) tuple: features is a Tensor or a dictionary of string feature names to Tensors; labels is a Tensor or a dictionary of string label names to Tensors. Both features and labels are consumed by model_fn (model_fn is one of the arguments to the constructor of tf.estimator.Estimator), so they should match what model_fn expects on its input side.
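For illustration, a minimal input_fn of the first kind, returning a tf.data.Dataset; the file name 'val.tfrecords' and the feature keys are placeholders, not part of the API:
import tensorflow as tf

def eval_input_fn():
    def _parse(record):
        parsed = tf.parse_single_example(record, features={
            'x': tf.FixedLenFeature([4], tf.float32),   # placeholder feature
            'y': tf.FixedLenFeature([], tf.int64),      # placeholder label
        })
        return {'x': parsed['x']}, parsed['y']          # (features, labels)
    dataset = tf.data.TFRecordDataset('val.tfrecords')  # placeholder path
    return dataset.map(_parse).batch(32)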
steps: Number of steps for which to evaluate the model. If None, evaluation keeps going until input_fn throws an end-of-input exception.
hooks: A list of SessionRunHook subclass instances, used as callbacks during evaluation. A sketch of a custom hook follows.
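For illustration, a minimal custom hook that counts evaluation batches (the class name is ours, not part of the API):
import tensorflow as tf

class EvalBatchCounterHook(tf.train.SessionRunHook):
    """Counts how many batches the evaluation loop processed."""
    def begin(self):
        self.count = 0
    def after_run(self, run_context, run_values):
        self.count += 1
    def end(self, session):
        tf.logging.info('evaluation processed %d batches', self.count)

# Usage: cnn_model.evaluate(input_fn=..., hooks=[EvalBatchCounterHook()])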
checkpoint_path: Path to a specific checkpoint to evaluate. If it is None, it defaults to the latest checkpoint in model_dir (model_dir is one of the arguments to the constructor of tf.estimator.Estimator).
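For example, pinning the evaluation to one saved step (the checkpoint file name here is illustrative):
evaluate_results = cnn_model.evaluate(
    input_fn=lambda: get_val_batch(val_file_path),
    checkpoint_path=save_model_path + '/model.ckpt-10000')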
name: Name of the evaluation. Users can run multiple evaluations on different datasets, such as train vs. test; the results of differently named evaluations are saved in separate folders and appear separately in TensorBoard.
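For example, evaluating the same model on both sets under different names, so each run is written to its own eval subfolder of model_dir:
train_metrics = cnn_model.evaluate(
    input_fn=lambda: get_train_batch(train_file_path), name='train')
val_metrics = cnn_model.evaluate(
    input_fn=lambda: get_val_batch(val_file_path), name='val')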
Return value
Returns a dictionary containing the evaluation metrics specified in model_fn, together with global_step.
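For example, assuming model_fn defines an 'accuracy' metric ('loss' and 'global_step' are always included):
results = cnn_model.evaluate(input_fn=lambda: get_val_batch(val_file_path))
print('loss %.4f, accuracy %.4f at step %d'
      % (results['loss'], results['accuracy'], results['global_step']))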
exception thrown
ValueError
: if step
less than or equal to 0
ValueError
: if the model_dir
specified model has not been trained, or if the specified checkpoint_path
is empty.
Example
First, define the Estimator:
cnn_model = tf.estimator.Estimator(
    model_fn=model_function, model_dir=save_model_path)
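model_function above is user code; for context, a minimal sketch of what such a model_fn can look like. The single dense layer and the 'accuracy' metric are purely illustrative, and we assume features is already a flat float vector (PREDICT mode is omitted for brevity):
import tensorflow as tf

def model_function(features, labels, mode):
    # Toy model: one dense layer producing 10-class logits.
    logits = tf.layers.dense(features, units=10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    # EVAL: the metrics declared here are exactly what evaluate() reports.
    metrics = {'accuracy': tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))}
    return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)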
Then train:
cnn_model.train(
    input_fn=lambda: get_train_batch(train_file_path), steps=steps_per_eval)
Finally, evaluate:
evaluate_results = cnn_model.evaluate(
    input_fn=lambda: get_val_batch(val_file_path),
    steps=eval_steps_per_train_cycle)
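The variable names suggest the two calls are alternated; a sketch of such a train/evaluate loop (conf.train_cycles is an assumed config value, not from the original code):
for _ in range(conf.train_cycles):
    cnn_model.train(
        input_fn=lambda: get_train_batch(train_file_path),
        steps=steps_per_eval)
    evaluate_results = cnn_model.evaluate(
        input_fn=lambda: get_val_batch(val_file_path),
        steps=eval_steps_per_train_cycle)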
Here the data is read from TFRecord files:
import tensorflow as tf   # conf below is the project's own configuration module

def get_train_batch(data_dir, batch_size=conf.batch_size, set_name='train', use_distortion=True):
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)

def get_val_batch(data_dir, batch_size=conf.batch_size, set_name='val', use_distortion=False):
    dataset = DataSet(data_dir, set_name, use_distortion)
    return dataset.get_batch(data_dir, batch_size)

class DataSet(object):
    ....  # elided; presumably __init__ stores set_name and use_distortion

    def get_batch(self, file_path, batch_size):
        """Build the input pipeline and return one (images, labels) batch.

        :param file_path: glob pattern matching the TFRecord files
        :param batch_size: differs between train, val and test
        """
        files = tf.data.Dataset.list_files(file_path)
        # Read several TFRecord files in parallel; sloppy=True trades a
        # deterministic element order for throughput.
        dataset = files.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset,
            cycle_length=conf.num_parallel_readers,
            sloppy=True))
        if self.set_name == 'train':
            dataset = dataset.repeat(conf.train_epochs)
            dataset = dataset.shuffle(conf.shuffle_buffer_size)
        # Parse and batch in one fused step.
        dataset = dataset.apply(tf.contrib.data.map_and_batch(
            map_func=self.parser_single_img,
            batch_size=batch_size,
            num_parallel_batches=conf.num_parallel_batches))
        # Prefetch to overlap preprocessing with the model; note the buffer
        # size here is counted in batches, not individual examples.
        dataset = dataset.prefetch(conf.batch_size)
        iterator = dataset.make_one_shot_iterator()
        img_batch, label_batch = iterator.get_next()
        return img_batch, label_batch
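get_batch relies on self.parser_single_img, which is not shown above. A minimal sketch, assuming each record holds a JPEG-encoded image and an integer label; the feature keys and the conf.img_* values are guesses, not taken from the original code:
    def parser_single_img(self, record):
        parsed = tf.parse_single_example(record, features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
        img = tf.image.decode_jpeg(parsed['image'], channels=3)
        img = tf.image.convert_image_dtype(img, tf.float32)
        if self.use_distortion:
            img = tf.image.random_flip_left_right(img)  # example distortion only
        img = tf.image.resize_images(img, [conf.img_height, conf.img_width])
        return img, parsed['label']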