tensorflow保存模型的几种方法

tensorflow保存模型有多种方法

第一种:saver.save(sess, "./hello_model") # 生成ckpt模型文件, hello_model.data-00000-of-00001  hello_model.index  hello_model.meta

第二种:tf.train.write_graph(sess.graph_def, ./,  'hello.pb') # 生成hello.pb, 再通过freeze_graph把hello.pb与ckpt固化成新的pb文件

第三种:用tf.graph_util.convert_variables_to_constants把变量转成常量之后写入PB文件中

第四种:使用tf.saved_model.builder.SavedModelBuilder

具体看代码,及示例

保存模型的文件 saver_hello.py

import tensorflow as tf
import sys
import os

# 把变量转成常量之后写入PB文件中
def SaveFrozenPb(nodeNameList, pbFile):
    gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), nodeNameList)
    with tf.gfile.GFile(pbFile, 'wb') as f:
        f.write(gd.SerializeToString())

# 通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件
# freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb
# 如果不调用freeze_graph, 直接使用会报错‘google.protobuf.message.DecodeError: Error parsing message’
def SavePbForFreezeGraph(pbDir, pbName):
    tf.train.write_graph(sess.graph_def, pbDir, pbName)

def SaveBuilderPb(pbDir):
    builder = tf.saved_model.builder.SavedModelBuilder(pbDir)
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
    builder.save()

if __name__ == '__main__':
    hello = tf.Variable(tf.constant('Hello World', name = "hello")) # 要save成功,需要tf.Variable, 否则会报错'ValueError: No variables to save'
    x = tf.placeholder(tf.float32, name="x")
    y = tf.multiply(x, 2, name="y")

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    saver = tf.train.Saver()

    typeStr = sys.argv[1]
    if typeStr == 'ckpt' or typeStr == 'pbNotFrozen':
        saver.save(sess, "./hello_model", write_meta_graph=True) # hello_model.data-00000-of-00001  hello_model.index  hello_model.meta
    if typeStr == 'pbNotFrozen':
        SavePbForFreezeGraph('./', 'hello.pb') # 需要经由freeze_graph工具处理
    elif typeStr == 'pbFrozen':
        SaveFrozenPb(['x', 'y', 'hello'], './hello_frozen.pb') # 无需再经由freeze_graph工具处理
    elif typeStr == 'builderPb':
        SaveBuilderPb('./save/')

加载模型文件restore_hello.py

import tensorflow as tf
import sys

def RestoreMeta(sess, name):
    #ckpt = tf.train.get_checkpoint_state('./')
    #restore = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
    #restore.restore(sess, ckpt.model_checkpoint_path)
    restore = tf.train.import_meta_graph(name)
    restore.restore(sess, "hello_model")

def RestorePb(sess, name):
    # 二进制读取模型文件
    with tf.gfile.FastGFile(name, 'rb') as f:                     
       graph_def = tf.GraphDef()                     
       graph_def.ParseFromString(f.read())                     
       sess.graph.as_default()                     
       tf.import_graph_def(graph_def, name='') # 导入计算图 

def RestoreBuilderPb(sess, pbDir):
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], pbDir)

if __name__ == '__main__':
    sess = tf.Session()
    typeStr = sys.argv[1]
    if typeStr == 'ckpt':
        RestoreMeta(sess, 'hello_model.meta')
    elif typeStr == 'pbFrozen':
        RestorePb(sess, './hello_frozen.pb')
    elif typeStr == 'builderPb':
        RestoreBuilderPb(sess, './save/')

    x = tf.get_default_graph().get_tensor_by_name("x:0")
    y = tf.get_default_graph().get_tensor_by_name("y:0")
    hello  = tf.get_default_graph().get_tensor_by_name("hello:0")

    print(sess.run(y, feed_dict={x:5})) # 10.0
    print(sess.run(hello)) # b'Hello World'

第一种:ckpt

保存模型

python3 ./saver_hello.py ckpt

生成checkpoint hello_model.data-00000-of-00001  hello_model.index  hello_model.meta
加载模型

python3 ./restore_hello.py ckpt

运行结果

10.0
b'Hello World'
 

第二种:ckpt+pb+固化 

python3 ./saver_hello.py pbNotFrozen
生成checkpoint   hello_model.data-00000-of-00001  hello_model.index  hello_model.meta  hello.pb

固化

freeze_graph --input_graph=./hello.pb --input_checkpoint=./hello_model --output_node_names=hello,y --input_node_names=x --output_graph=./hello_frozen.pb

加载

python3 ./restore_hello.py pbFrozen
 

第三种:固化的pb

保存

python3 ./saver_hello.py pbFrozen

加载

python3 ./restore_hello.py pbFrozen

第四种:

python3 ./saver_hello.py builderPb
python3 ./restore_hello.py builderPb

发布了201 篇原创文章 · 获赞 20 · 访问量 41万+

猜你喜欢

转载自blog.csdn.net/zmlovelx/article/details/100176865
今日推荐