本文介绍了 Estimators 模型的保存和恢复。
TensorFlow提供了两种模型格式:
- checkpoints:这种格式依赖于创建模型的代码。
- SavedModel:这种格式与创建模型的代码无关。
本文档主要介绍checkpoints。
1. 保存经过部分训练的模型
Estimators 在训练过程中会自动将以下内容保存到磁盘:
- chenkpoints:训练过程中的模型快照。
- event files:其中包含 TensorBoard 用于创建可视化图表的信息。
通过 model_dir 参数,我们可以指定 Estimator 保存上述文件时的顶级目录。
# 实例化 estimator
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
# 训练 estimator
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下图所示,第一次调用 train
方法会将 checkpoints 和 event files 文件添加到 model_dir
目录中。
来查看 model_dir
目录中的内容:
我们可以看到,Estimator在step 1(训练开始)和step 12130(训练结束)创建了 checkpoints 文件。
1.2 创建 Checkpoints 的频率
默认情况下,Estimator 会根据以下策略来写入 checkpoints。
- 每10分钟(600秒)向磁盘写入一个 checkpoint。
- 在 train 方法开始(第一次迭代)和结束(最后一次迭代)时写入一个 checkpoint。
- model_dir 目录中保留 5 个最近写入的检查点。
当然,你可以按如下方式修改 checkpoint 的写入策略:
- 创建一个tf.estimator.RunConfig对象来定义 checkpoint 写入策略。
- 在实例化 Estimator 时,将 RunConfig 对象传给 Estimator 的 config 参数。
下面的代码将 checkpoint 写入间隔设置为20分钟,并且保留最近的10个 checkpoints:
est_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # 每20分钟保存一次 checkpoints
keep_checkpoint_max = 10, # 保留最新的10个checkpoints
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=est_config)
3. 恢复模型
第一次调用 Estimator 的 train 方法时,TensorFlow会保存 checkpoint 文件到 model_dir 目录。随后调用 tarin、evaluate、predict 方法将进行如下操作:
- Estimator 通过运行 model_fn 来构建模型的计算图。
- Estimator 从 checkpoints 中初始化模型参数。
2.1 避免不当恢复
仅在模型和checkpoint兼容的情况下,才能从 checkpoint 恢复模型的状态。例如,假设您训练了DNNClassifier包含两个隐藏层的 Estimator,每个隐藏层有10个节点:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
训练之后,如果您将每个隐藏层中的神经元数量从10更改为20,并尝试重新训练模型:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier2.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
由于 checkpoint 中的状态与描述的模型不兼容,因此重新训练失败并出现以下错误:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
如果保存报错,粗暴的方式是:将原来保存的模型文件删除掉,新文件下保存模型不会出错。
参考:
tensorflow中模型的保存与使用总结 — carlos9310
https://blog.csdn.net/u014061630/article/details/82901646