用一个非常简单的例子学习导出和加载模型;
导出
写一个y=a*x+b的运算,然后保存graph;
import tensorflow as tf #from tensorflow.python.framework.graph_util import convert_variables_to_constants with tf.Session() as sess: a = tf.Variable(5.0, name='a') x = tf.Variable(6.0, name='x') b = tf.Variable(3.0, name='b') y = tf.add(tf.multiply(a,x),b, name="y") sess.run(tf.initialize_all_variables()) print (a.eval()) # 5.0 print (x.eval()) # 6.0 print (b.eval()) # 3.0 print (y.eval()) # 33.0 #graph = convert_variables_to_constants(sess, sess.graph_def, ["y"]) #writer = tf.summary.FileWriter("logs/", graph) tf.train.write_graph(graph, 'models/', 'test_graph.pb', as_text=False)
运行
在models目录下生成了test_graph.pb;
加载
只加载,获取各个变量的值
import tensorflow as tf from tensorflow.python.platform import gfile with gfile.FastGFile("models/test_graph.pb", 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, return_elements=['a:0', 'x:0', 'b:0','y:0']) #print(output) with tf.Session() as sess: result = sess.run(output) print (result)
运行看以看到原本保存的结果(因为几个变量都是常量)
加载的时候修改变量值
5*2+3=13,结果正确
运行时修改变量值
加载时用一个占位符替掉x常量,在session运行时再给占位符填值;
5*3+3=18,也正确
修改计算结果
偷偷把结果给改了会怎么样?
呵呵,不知原因为何;以后钻进代码了再说;
参考:
https://www.sohu.com/a/233679628_468681
http://blog.163.com/wujiaxing009@126/blog/static/7198839920174125748893/