TensorFlow 静态图 pb(frozen graph)模型的保存与调用

pb模型保存

基于tf2

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

model = ...  # placeholder: substitute a built Keras model here

# Convert Keras model to ConcreteFunction.
# NOTE: the signature is taken from the model's first input only.
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction (variables are converted to constants)
frozen_func = convert_variables_to_constants_v2(full_model)

# Save frozen graph from frozen ConcreteFunction to hard drive
# (binary protobuf at ./frozen_models/frozen_graph.pb)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

基于keras (tf1)

from tensorflow.keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the session's graph by converting its variables to constants.

    Args:
        session: TF1-style session whose graph is frozen.
        keep_var_names: Optional iterable of variable op names to exclude
            from the list of variables handed to the converter.
        output_names: Optional list of output node names passed to
            ``convert_variables_to_constants``; defaults to an empty list.
        clear_devices: If true, the ``device`` field of every node is
            blanked before conversion. If false, every node of the frozen
            graph is assigned to ``"/GPU:0"``.

    Returns:
        The frozen graph definition returned by
        ``graph_util.convert_variables_to_constants``.
    """
    # Imported locally so the helper does not depend on module-level names
    # (`tf`, `graph_util`) that this file never imports.
    from tensorflow.compat.v1 import global_variables, graph_util

    graph = session.graph
    with graph.as_default():
        all_var_names = {v.op.name for v in global_variables()}
        freeze_var_names = list(all_var_names.difference(keep_var_names or []))
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        frozen_graph = graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        if not clear_devices:
            # NOTE(review): this overwrites every node's placement with GPU 0
            # rather than preserving the original devices - confirm intended.
            for node in frozen_graph.node:
                node.device = "/GPU:0"
        return frozen_graph


from tensorflow import keras

# load model
# NOTE: model_from_json restores the architecture only; presumably the weights
# must be loaded as well before freezing - confirm against the real pipeline.
model = keras.models.model_from_json(...)


# save pb model
out_path = 'model.pb'
# Blank device placements in the exported graph (same as freeze_session's default).
clear_devices = True
input_names = [n.op.name for n in model.inputs]
output_names = [n.op.name for n in model.outputs]
print(input_names, output_names)
frozen_graph = freeze_session(K.get_session(), output_names=output_names, clear_devices=clear_devices)
with open(out_path, "wb") as f:
    f.write(frozen_graph.SerializeToString())

模型调用

这里以 TF1 为例(使用 tensorflow.compat.v1 接口):

import numpy as np
from tensorflow.compat.v1 import Graph, GraphDef, import_graph_def, Session
from tensorflow.compat.v1.gfile import GFile

frozen_graph = "model.pb"
# import graph: parse the serialized GraphDef from disk
with GFile(frozen_graph, "rb") as f:
    graph_def = GraphDef()
    graph_def.ParseFromString(f.read())
with Graph().as_default() as graph:
    # name="" imports the nodes without a name prefix, so the tensor names
    # looked up below match the names stored in the pb file.
    import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=""
                     )

# set input output (replace with the real tensor names of the model)
x = graph.get_tensor_by_name("input:0")
y1 = graph.get_tensor_by_name("output1:0")
y2 = graph.get_tensor_by_name("output2:0")

# get batch_input
batch_image = np.zeros([1, 512, 512, 3])
# get ...

# predict; the session is closed automatically when the block exits
feed_dict_testing = {x: batch_image}
with Session(graph=graph) as sess:
    output1, output2 = sess.run([y1, y2], feed_dict=feed_dict_testing)

猜你喜欢

转载自blog.csdn.net/dou3516/article/details/110871797