tf.graph()的声明,保存与调用

# -*- coding: utf-8 -*-  
import tensorflow as tf
from tensorflow.python.platform import gfile

#图的声明
g1 = tf.Graph()#声明图g1
with g1.as_default():
    # 需要加上名称,在读取pb文件的时候,是通过name和下标来取得对应的tensor的
    c1 = tf.constant(4.0, name='c1')

g2 = tf.Graph()#声明图g2
with g2.as_default():
    c2 = tf.constant(20.0,name='c2')

with tf.Session(graph=g1) as sess1:
    print(sess1.run(c1))#4.0
with tf.Session(graph=g2) as sess2:
    print(sess2.run(c2))#20.0

#图的保存
# g1的图定义,包含pb的path, pb文件名,是否是文本默认False
tf.train.write_graph(g1.as_graph_def(), '.', 'graph1.pb', False)#保存第一个图g1.'.',是默认保存在当前目录下
tf.train.write_graph(g2.as_graph_def(), '.', 'graph2.pb', False)#保存第二个图g2

#图的调用
# load graph
with gfile.FastGFile("./graph2.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

sess = tf.Session()
c1_tensor = sess.graph.get_tensor_by_name("c2:0")
c1 = sess.run(c1_tensor)
print(c1)#20

猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/89947090