Deep learning和tensorflow学习记录(二十六):tf.estimator.train_and_evaluate

tf.estimator.train_and_evaluate(
    estimator,
    train_spec,
    eval_spec
)

定义于:tensorflow/python/estimator/training.py

训练和评估estimator

该通用函数使用给定的estimator训练,评估和(可选地)导出模型。所有训练相关的规范都包含在train_spec内,包括训练input_fn和最大训练次数等。所有评估和导出相关的规范都包含在eval_spec内,包括评估input_fn,次数等。

此通用函数本地(非分布式)和分布式配置是一致的。目前,唯一支持的分布式培训配置是图间复制。

Overfitting:为了避免Overfitting,建议设置训练input_fn以适当地改变训练数据。在进行评估之前,还建议将模型多训练一段时间,比如多个epochs,因为输入管道从头开始进行每次训练。这对本地的训练和评估尤为重要。

Stop condition:为了可靠地支持分布式和非分布式配置,模型训练唯一支持的Stop condition是train_spec.max_steps。如果train_spec.max_stepsNone,模型将永远训练下去。如果模型Stop condition不同,请小心使用。例如,假设预期模型将使用一个epoch训练数据进行训练,并且训练input_fn被配置为 在经过一个epoch训练之后抛出OutOfRangeError,停止训练Estimator.train。对于 three-training-worker分布式配置,每个training worker可能在整个epoch独立地完成训练。因此,该模型将使用三个epoches训练数据而不是一个epoch进行训练。

本地(非分布式)训练示例:

# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
    categorical_column=categorial_feature_a, ...)
...  # other feature columns

estimator = DNNClassifier(
    feature_columns=[categorial_feature_a_emb, ...],
    hidden_units=[1024, 512, 256])

# Or set up the model directory
#   estimator = DNNClassifier(
#       config=tf.estimator.RunConfig(
#           model_dir='/my_model', save_summary_steps=100),
#       feature_columns=[categorial_feature_a_emb, ...],
#       hidden_units=[1024, 512, 256])

# Input pipeline for train and evaluate.
def train_input_fn: # returns x, y
  # please shuffle the data.
  pass
def eval_input_fn_eval: # returns x, y
  pass

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

分布式训练示例:

关于分布式训练的示例,上面的代码可以在没有更改的情况下使用(请确保将RunConfig.model_dir所有workers 设置为相同的目录,即所有workers 都可以读写的共享文件系统)。唯一需要做的额外工作是相应地为每个workers 正确设置环境变量TF_CONFIG

另请参阅 Distributed TensorFlow

设置环境变量取决于平台。例如,在Linux上,它可以按如下方式完成($是shell提示符):

$ TF_CONFIG='<replace_with_real_content>' python train_model.py

对于内容TF_CONFIG,假设训练集群规范如下:

cluster = {"chief": ["host0:2222"],
           "worker": ["host1:2222", "host2:2222", "host3:2222"],
           "ps": ["host4:2222", "host5:2222"]}

TF_CONFIG主要训练workers 的例子(必须有一个且只有一个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222", "host3:2222"],
        "ps": ["host4:2222", "host5:2222"]
    },
    "task": {"type": "chief", "index": 0}
}'

请注意,主要workers 也进行模型训练工作,类似于其他非主要训练workers (见下一段)。除了模型训练之外,它还管理一些额外的工作,例如检查点保存和恢复,写入summaries等。

TF_CONFIG非主要训练workers 的示例(可选,可以是多个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222", "host3:2222"],
        "ps": ["host4:2222", "host5:2222"]
    },
    "task": {"type": "worker", "index": 0}
}'

其中task.index应分别设定为0,1,2,在这个例子中,用于非主要训练workers 。

TF_CONFIG参数服务器的示例,也就是ps(可能是多个):

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222", "host3:2222"],
        "ps": ["host4:2222", "host5:2222"]
    },
    "task": {"type": "ps", "index": 0}
}'

其中task.index应分别设置为0和1,在本例中,分别为参数服务器。

TF_CONFIG评估任务的示例。Evaluator是一项特殊任务,不属于训练集群。可能只有一个,它用于模型评估。

# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
    "cluster": {
        "chief": ["host0:2222"],
        "worker": ["host1:2222", "host2:2222", "host3:2222"],
        "ps": ["host4:2222", "host5:2222"]
    },
    "task": {"type": "evaluator", "index": 0}
}'

参数:

  • estimatorEstimator训练和评估的实例。

  • train_specTrainSpec指定训练规范的实例。

  • eval_specEvalSpec指定评估和导出规范的实例。

返回:

evaluate的结果元组,和指定ExportStrategy导出的结果。目前,分布式训练模式的返回值未定义。

 

 

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/81076409