网络结构保存为pb(二)

第一种保存方式

import tensorflow as tf
from tensorflow.python.framework import graph_util
# Directory that receives the serialized graph.
logdir = './checkpoints/'
# Create the output directory up front so the write below does not depend on
# it already existing.
tf.gfile.MakeDirs(logdir)

# Two randomly initialized variables, named "conv/w" and "conv/b".
with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer)

# The context manager closes the session even if freezing the graph raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables
    # Replace variables with constants holding their current values.
    # Only "conv/w" is listed as an output node, so nodes that are not needed
    # to compute it (such as "conv/b") are not part of the exported graph.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["conv/w"])

# Write the frozen GraphDef as a binary protobuf file.
with tf.gfile.GFile(logdir + 'expert-graph.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())

第二种保存方式(带占位符)

 
 
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Model files are written to the "checkpoints" folder under the current
# directory (use './' to write to the current directory instead).
logdir = './checkpoints/'
# Create the output directory up front so the write below does not depend on
# it already existing.
tf.gfile.MakeDirs(logdir)

# Build everything in a dedicated graph rather than the global default graph.
g1 = tf.Graph()
with g1.as_default():
    # Placeholder named "input"; readers look it up as "input:0".
    x = tf.placeholder(name='input', dtype=tf.float32, shape=[None, 3])
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer)

    # The context manager closes the session on exit, so no explicit
    # sess.close() is needed.
    with tf.Session(graph=g1) as sess:
        tf.global_variables_initializer().run()
        # Export the graph structure as-is, which keeps the placeholder even
        # though no other node consumes it.
        # NOTE: as_graph_def() does not embed the variables' current values.
        # To store them as constants, freeze the graph with
        # graph_util.convert_variables_to_constants and list "input" among the
        # output node names so the placeholder is retained.
        constant_graph = tf.get_default_graph().as_graph_def()
        with tf.gfile.GFile(logdir + 'expert-graph.pb', mode='wb') as f:
            f.write(constant_graph.SerializeToString())


pb的读取

import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
# Location of the serialized graph produced by the save scripts.
logdir = './checkpoints/'
output_graph_path = logdir + 'expert-graph.pb'

# Parse the binary protobuf into a GraphDef before opening a session.
output_graph_def = tf.GraphDef()
with open(output_graph_path, "rb") as f:
    output_graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # name="" imports the nodes under their original names (no "import/"
    # prefix), so tensors can be looked up as "input:0" and "w:0".
    # No variable initializer is run here: at this point the process has
    # created no variables of its own.
    tf.import_graph_def(output_graph_def, name="")

    input_x = sess.graph.get_tensor_by_name("input:0")
    print(input_x)
    # The fetched tensor is the placeholder itself, so the fed batch is
    # returned unchanged.
    y = input_x
    out = sess.run(y, {input_x: np.random.random([1, 3])})
    print(out[:10])  # show at most the first 10 rows of the result
    # Looks up the tensor by name and prints its description; it is not
    # evaluated.
    w = sess.graph.get_tensor_by_name("w:0")
    print(w)



猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/79522654