参考:https://zhuanlan.zhihu.com/p/34471266
简介
TensorFlow的模型格式有很多种,针对不同场景可以使用不同的格式,只要符合规范的模型都可以轻易部署到在线服务或移动设备上,这里简单列举一下。
- Checkpoint: 用于保存模型的权重,主要用于模型训练过程中参数的备份和模型训练热启动。
- GraphDef:用于保存模型的Graph,不包含模型权重,加上checkpoint后就有模型上线的全部信息。
- ExportModel:使用exportor接口导出的模型文件,包含模型Graph和权重可直接用于上线,但官方已经标记为deprecated推荐使用SavedModel。
- SavedModel:使用saved_model接口导出的模型文件,包含模型Graph和权限可直接用于上线,TensorFlow和Keras模型推荐使用这种模型格式。
- FrozenGraph:使用freeze_graph.py对checkpoint和GraphDef进行整合和优化,可以直接部署到Android、iOS等移动设备上。
- TFLite:基于flatbuf对模型进行优化,可以直接部署到Android、iOS等移动设备上,使用接口和FrozenGraph有些差异。
模型格式
目前建议TensorFlow和Keras模型都导出成SavedModel格式,这样就可以直接使用通用的TensorFlow Serving服务,模型导出即可上线不需要改任何代码。不同的模型导出时只要指定输入和输出的signature即可,其中字符串的key可以任意命名只会在客户端请求时用到,可以参考下面的代码示例。
注意,目前使用tf.py_func()的模型导出后不能直接上线,模型的所有结构建议都用op实现。
TensorFlow模型导出
import os
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (
signature_constants, signature_def_utils, tag_constants, utils)
from tensorflow.python.util import compat
model_path = "model"
model_version = 1
model_signature = signature_def_utils.build_signature_def(
inputs={
"keys": utils.build_tensor_info(keys_placeholder),
"features": utils.build_tensor_info(inference_features)
},
outputs={
"keys": utils.build_tensor_info(keys_identity),
"prediction": utils.build_tensor_info(inference_op),
"softmax": utils.build_tensor_info(inference_softmax),
},
method_name=signature_constants.PREDICT_METHOD_NAME)
export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
clear_devices=True,
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
model_signature,
},
legacy_init_op=legacy_init_op)
builder.save()
Keras模型导出
import os
import tensorflow as tf
from tensorflow.python.util import compat
def export_savedmodel(model):
model_path = "model"
model_version = 1
model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'input': model.input}, outputs={'output': model.output})
export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
clear_devices=True,
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
model_signature
})
builder.save()
SavedModel模型结构
使用TensorFlow的API导出SavedModel模型后,可以检查模型的目录结构如下,然后就可以直接使用开源工具来加载服务了。
模型上线
部署在线服务
使用HTTP接口可参考 tobegit3hub/simple_tensorflow_serving 。
使用gRPC接口可参考 tensorflow/serving 。
部署离线设备
部署到Android可参考 https://medium.com/@tobe_ml/all-tensorflow-models-can-be-embedded-into-mobile-devices-1932e80579e5 。
部署到iOS可参考 https://zhuanlan.zhihu.com/p/33715219 。
待更新