tensorflow对已经训练的模型进行优化和固化

前言

有时,我们需要保存tensorflow训练的模型:

  • tf.train.write_graph()默认情况下只导出了网络的定义(没有权重)
  • 利用tf.train.Saver().save()导出的文件graph_def与权重是分离的 为了方便使用模型,通过tensorflow.python.tools.freeze_graph可以将两者进行合并和优化最后得到最终的PB文件。

1.通过ckpt和tf.train.write_graph得到基础pb文件(无权重)

1.1在训练过程中使用使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件:
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.save(session, "model.ckpt")
    tf.train.write_graph(session.graph_def, '', 'graph.pb')
1.2非训练过程中使用(加载网络生成pb文件):

参考项目run_checkpoint.py

import argparse
import logging
import tensorflow as tf
from tf_pose.networks import get_network
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.allow_growth = True
if __name__ == '__main__':
    """
    Use this script to just save graph and checkpoint.
    While training, checkpoints are saved. You can test them with this python code.
    """
    parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor')
    parser.add_argument('--model', type=str, default='cmu', help='cmu / mobilenet_thin / my_mobilenet_thin')
    args = parser.parse_args()

    input_node = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='image')

    with tf.Session(config=config) as sess:
        #网络结构以及ckpt参数
        net, _, last_layer = get_network(args.model, input_node, sess, trainable=False)
        #net:网络结构
        #_ :ckpt参数
        #last_layer:网络最后一层名称
        print(last_layer)
        tf.train.write_graph(sess.graph_def, 'models/graph/my_mobilenet_thin/', 'graph.pb', as_text=True)

关于函数get_network:

def get_network(type, placeholder_input, sess_for_load=None, trainable=True):
    if type == 'mobilenet':
        net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable)
        pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
        last_layer = 'MConv_Stage6_L{aux}_5'
    return net, pretrain_path_full, last_layer

2.得到的pb文件与ckpt进行freezing:

$ python3 -m tensorflow.python.tools.freeze_graph \
  --input_graph=... \
  --output_graph=... \
  --input_checkpoint=... \
  --output_node_names="Openpose/concat_stage7"

参考资料:

  1. 将TensorFlow的网络导出为单个文件
  2. tensorflow-Freezing
  3. Model Optimization for Inference
  4. TensorFlow固化模型

猜你喜欢

转载自blog.csdn.net/m0_37477175/article/details/81187929