# 第一种保存方式：将变量固化为常量后保存为 .pb 文件
# Save method 1: freeze the variables into constants and write the graph
# to a binary .pb file.
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

logdir = './checkpoints/'
pb_path = os.path.join(logdir, 'expert-graph.pb')

with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer())
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer())

# A `with` block closes the session even if freezing raises; the previous
# InteractiveSession + trailing sess.close() leaked it on any error.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables
    # convert_variables_to_constants prunes the graph down to what the listed
    # nodes need, so every variable to keep must be listed: naming only
    # "conv/w" silently dropped "conv/b" from the saved file.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['conv/w', 'conv/b'])

# Writing into a missing directory fails, so make sure it exists first.
tf.gfile.MakeDirs(logdir)
with tf.gfile.GFile(pb_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# 第二种保存方式：计算图中包含输入占位符（placeholder）
# Save method 2: a graph that also contains an input placeholder. The
# variables are frozen into constants so their values are stored in the file.
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util

logdir = './checkpoints/'  # save under ./checkpoints/ relative to the working dir
# logdir = './'  # alternative: save directly in the working dir
pb_path = os.path.join(logdir, 'expert-graph.pb')

g1 = tf.Graph()
with g1.as_default():
    # No variable_scope is used, so the nodes are named 'input', 'w' and 'b'.
    x = tf.placeholder(name='input', dtype=tf.float32, shape=[None, 3])
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer())
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer())

with tf.Session(graph=g1) as sess:
    sess.run(tf.global_variables_initializer())  # initialize w and b
    # Serializing the plain graph_def stores the variable ops but not their
    # values, which made the initialization above pointless. Freeze instead,
    # and list 'input' among the output nodes: the placeholder is not
    # connected to w or b, so it would otherwise be pruned from the result.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, g1.as_graph_def(), ['input', 'w', 'b'])

# Writing into a missing directory fails, so make sure it exists first.
tf.gfile.MakeDirs(logdir)
with tf.gfile.GFile(pb_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# 读取 .pb 文件
# Load a serialized GraphDef from a .pb file, import it into the default
# graph and run it by feeding the 'input' placeholder.
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util

logdir = './checkpoints/'
output_graph_path = os.path.join(logdir, 'expert-graph.pb')

output_graph_def = tf.GraphDef()
with open(output_graph_path, 'rb') as f:
    output_graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # No variables initializer here: running one before the import covered
    # nothing, and import_graph_def does not register imported variables for
    # it either. Values are only available if the .pb was frozen.
    # name='' keeps the saved node names instead of adding an 'import/' prefix.
    tf.import_graph_def(output_graph_def, name='')

    input_x = sess.graph.get_tensor_by_name('input:0')
    print(input_x)
    # Fetching the fed tensor itself simply echoes the fed batch back.
    out = sess.run(input_x, feed_dict={input_x: np.random.random([1, 3])})
    print(out[:10])  # at most the first 10 rows (this batch has only 1)

    w = sess.graph.get_tensor_by_name('w:0')
    print(w)  # prints the Tensor handle; use sess.run(w) only on a frozen .pb