TensorFlow Installation and Usage

Installation

Usage

Model Optimization

(1) Inspect the inputs and outputs of a saved_model

# bazel build tensorflow/python/tools:saved_model_cli
# saved_model_cli show --dir detection/ --all
or
# python3 /<path>/tensorflow/tensorflow/python/tools/saved_model_cli.py show --dir detection/ --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['image'] tensor_info:
        dtype: DT_UINT8
        shape: (1, -1, -1, 3)
        name: image:0
    inputs['true_image_shape'] tensor_info:
        dtype: DT_INT32
        shape: (1, 3)
        name: true_image_shape:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 4)
        name: ChangeCoordToOriginalImage/stack:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_INT32
        shape: (1, -1)
        name: add:0
    outputs['detection_keypoints'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 4, 2)
        name: TextKeypointPostProcess/Reshape_2:0
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1)
        name: strided_slice_3:0
    outputs['num_detections'] tensor_info:
        dtype: DT_INT32
        shape: (1)
        name: BatchMultiClassNonMaxSuppression/stack_8:0
  Method name is: tensorflow/serving/predict
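The names passed to --input_names and --output_node_names in the later steps are the tensor names from this listing with the :0 output index removed. Instead of copying them by hand, a short Python sketch can pull them out of the saved_model_cli text. It assumes a single signature laid out as above (an inputs[...]/outputs[...] tensor_info header, then dtype:, shape:, name: lines); parse_signature and node_name are illustrative helpers, not part of TensorFlow.

```python
import re

# Excerpt of the "saved_model_cli show --all" output from step (1).
CLI_OUTPUT = """\
    inputs['image'] tensor_info:
        dtype: DT_UINT8
        shape: (1, -1, -1, 3)
        name: image:0
    inputs['true_image_shape'] tensor_info:
        dtype: DT_INT32
        shape: (1, 3)
        name: true_image_shape:0
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 4)
        name: ChangeCoordToOriginalImage/stack:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_INT32
        shape: (1, -1)
        name: add:0
    outputs['detection_keypoints'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 4, 2)
        name: TextKeypointPostProcess/Reshape_2:0
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1)
        name: strided_slice_3:0
    outputs['num_detections'] tensor_info:
        dtype: DT_INT32
        shape: (1)
        name: BatchMultiClassNonMaxSuppression/stack_8:0
"""

# Matches e.g. "inputs['image'] tensor_info:" (assumed header format).
HEADER = re.compile(r"(inputs|outputs)\['([^']+)'\] tensor_info:")


def parse_signature(text):
    """Collect (key, dtype, tensor_name) tuples for every input and output."""
    result = {"inputs": [], "outputs": []}
    current = None  # the tensor_info entry currently being read
    for raw in text.splitlines():
        line = raw.strip()
        header = HEADER.match(line)
        if header:
            current = {"section": header.group(1), "key": header.group(2)}
        elif current is not None and line.startswith("dtype:"):
            current["dtype"] = line.split(":", 1)[1].strip()
        elif current is not None and line.startswith("name:"):
            tensor = line.split(":", 1)[1].strip()
            result[current["section"]].append(
                (current["key"], current.get("dtype"), tensor))
            current = None  # "name:" is the last field of an entry
    return result


def node_name(tensor_name):
    """'stack:0' -> 'stack': drop the output-index suffix of a tensor name."""
    return tensor_name.rsplit(":", 1)[0]


if __name__ == "__main__":
    sig = parse_signature(CLI_OUTPUT)
    print("--input_names", ",".join(node_name(t) for _, _, t in sig["inputs"]))
    print("--output_node_names",
          ",".join(node_name(t) for _, _, t in sig["outputs"]))
```

Run on this model's output, the two joined strings are the same comma-separated lists used in the freeze_graph and optimize_for_inference commands.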

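(2) Convert the TF saved_model into a frozen_model

--output_node_names takes node names, not tensor names: they are the output tensor names from (1) with the :0 suffix removed.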

# bazel build tensorflow/python/tools:freeze_graph
# freeze_graph --input_saved_model_dir detection/ --output_graph detection_frozen_model.pb --output_node_names ChangeCoordToOriginalImage/stack,add,TextKeypointPostProcess/Reshape_2,strided_slice_3,BatchMultiClassNonMaxSuppression/stack_8
or
# python3 /<path>/tensorflow/tensorflow/python/tools/freeze_graph.py --input_saved_model_dir detection/ --output_graph detection_frozen_model.pb --output_node_names ChangeCoordToOriginalImage/stack,add,TextKeypointPostProcess/Reshape_2,strided_slice_3,BatchMultiClassNonMaxSuppression/stack_8

(3) Run the frozen_model through optimize_for_inference to get an optimized_model

# bazel build tensorflow/python/tools:optimize_for_inference   # output: bazel-bin/tensorflow/python/tools/optimize_for_inference
# optimize_for_inference --input detection_frozen_model.pb --output detection_optimized_model.pb --input_names image,true_image_shape --output_names ChangeCoordToOriginalImage/stack,add,TextKeypointPostProcess/Reshape_2,strided_slice_3,BatchMultiClassNonMaxSuppression/stack_8 --frozen_graph true --placeholder_type_enum 4,3,1,3,1,1,3
or
# python3 /<path>/tensorflow/tensorflow/python/tools/optimize_for_inference.py --input detection_frozen_model.pb --output detection_optimized_model.pb --input_names image,true_image_shape --output_names ChangeCoordToOriginalImage/stack,add,TextKeypointPostProcess/Reshape_2,strided_slice_3,BatchMultiClassNonMaxSuppression/stack_8 --frozen_graph true --placeholder_type_enum 4,3,1,3,1,1,3

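The values of --placeholder_type_enum are DataType enum codes (1 = DT_FLOAT, 3 = DT_INT32, 4 = DT_UINT8), defined in:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto
The tool expects one value per --input_names entry, so 4,3 makes image a DT_UINT8 placeholder and true_image_shape a DT_INT32 placeholder; the trailing 1,3,1,1,3 only mirror the dtypes of the five outputs. If the flag is omitted, placeholders default to DT_FLOAT, which would not match the uint8 image input.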
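To double-check such a list, a small Python lookup table is enough. The table is a hand-copied subset of the DataType enum in types.proto (codes 1 to 10 only), and decode_enum_list / encode_dtypes are illustrative helpers, not TensorFlow functions. Decoding 4,3,1,3,1,1,3 gives exactly the dtypes saved_model_cli reports in (1): the two inputs followed by the five outputs.

```python
# Subset of the DataType enum from tensorflow/core/framework/types.proto
# (copied by hand; only codes 1-10 are included here).
DATA_TYPES = {
    1: "DT_FLOAT",
    2: "DT_DOUBLE",
    3: "DT_INT32",
    4: "DT_UINT8",
    5: "DT_INT16",
    6: "DT_INT8",
    7: "DT_STRING",
    8: "DT_COMPLEX64",
    9: "DT_INT64",
    10: "DT_BOOL",
}
# Reverse lookup: dtype name -> enum code.
CODES = {name: code for code, name in DATA_TYPES.items()}


def decode_enum_list(flag_value):
    """'4,3' -> ['DT_UINT8', 'DT_INT32']"""
    return [DATA_TYPES[int(v)] for v in flag_value.split(",")]


def encode_dtypes(dtype_names):
    """['DT_UINT8', 'DT_INT32'] -> '4,3'"""
    return ",".join(str(CODES[name]) for name in dtype_names)


if __name__ == "__main__":
    print(decode_enum_list("4,3,1,3,1,1,3"))
    # dtypes of the image and true_image_shape inputs, as listed in (1)
    print(encode_dtypes(["DT_UINT8", "DT_INT32"]))
```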

(4) Export a pb model as a graph that TensorBoard can visualize

# bazel build tensorflow/python/tools:import_pb_to_tensorboard
# import_pb_to_tensorboard --model_dir ./recognition_frozen_model.pb --log_dir ./recognition_frozen_model.graph
or
# python3 /<path>/tensorflow/tensorflow/python/tools/import_pb_to_tensorboard.py --model_dir ./recognition_frozen_model.pb --log_dir ./recognition_frozen_model.graph
# nohup tensorboard --logdir=./recognition_frozen_model.graph --port=6006 2>&1 &

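For how to use the TensorBoard visualization tool, see: https://blog.csdn.net/gg_18826075157/article/details/78440766
Once TensorBoard is running, open port 6006 in a browser (e.g. http://localhost:6006 when running locally).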

(5) Quantize and optimize a frozen pb model

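transform_graph is a C++ tool under tensorflow/tools/graph_transforms, so it is built with bazel and run from bazel-bin rather than through python3:

# bazel build tensorflow/tools/graph_transforms:transform_graph
# bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph="./detection_frozen_model.pb" --out_graph="./detection_transformed_model.pb" --inputs="image,true_image_shape" --outputs="ChangeCoordToOriginalImage/stack,add,TextKeypointPostProcess/Reshape_2,strided_slice_3,BatchMultiClassNonMaxSuppression/stack_8" --transforms='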
  add_default_attributes
  strip_unused_nodes()
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms
  quantize_weights'
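quantize_weights stores large float constants as 8-bit values plus their range and adds a conversion back to float, so the rest of the graph still computes in float; the main gain is a smaller .pb file (1 byte per weight instead of 4 for float32). A minimal Python sketch of the min/max linear 8-bit mapping behind that idea, on a plain list. It is a simplified illustration, not TensorFlow's exact rounding or op layout, and quantize_uint8 / dequantize_uint8 are hypothetical names.

```python
def quantize_uint8(weights):
    """Linearly map floats onto the integers 0..255 using the tensor's min/max."""
    lo, hi = min(weights), max(weights)
    # Width of one quantization step; a constant tensor gets a dummy step of 1.0.
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    # Round to the nearest step and clamp to the uint8 range.
    q = [min(255, max(0, round((w - lo) / scale))) for w in weights]
    return q, lo, scale


def dequantize_uint8(q, lo, scale):
    """Recover approximate floats; each value is off by at most half a step."""
    return [lo + v * scale for v in q]


if __name__ == "__main__":
    weights = [-0.51, -0.02, 0.0, 0.33, 0.74]
    q, lo, scale = quantize_uint8(weights)
    restored = dequantize_uint8(q, lo, scale)
    print("8-bit codes:", q)
    print("max abs error:", max(abs(a - b) for a, b in zip(weights, restored)))
    print("bytes as float32 vs uint8:", 4 * len(weights), len(q))
```

The per-value error is bounded by half a step, (max - min) / 510, so tensors with a few extreme outliers lose the most precision.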

Reposted from www.cnblogs.com/qccz123456/p/11504660.html