打印 TensorFlow 模型（.pb 文件）中每一层（节点）的名称

import tensorflow as tf

def print_graph_nodes(pb_path):
    """Load a serialized TensorFlow graph from a ``.pb`` file and print each node name.

    The ``GraphDef`` is imported into a fresh ``tf.Graph`` rather than the
    process-wide default graph. ``import_graph_def`` is called with its
    default ``name`` argument, so every printed name carries the
    ``import/`` scope prefix (same output as the original snippet).

    Args:
        pb_path: Path to a binary-serialized ``GraphDef`` file. It must be a
            concrete file name; wildcards such as ``*.pb`` are not expanded.

    Raises:
        Propagates whatever file reading or protobuf parsing raises, e.g. for a
        missing file or a file that is not a serialized ``GraphDef``.
    """
    graph_def = tf.compat.v1.GraphDef()
    # "rb": a .pb model is a binary protobuf, not text.
    with tf.io.gfile.GFile(pb_path, "rb") as model_file:
        graph_def.ParseFromString(model_file.read())

    graph = tf.Graph()
    # import_graph_def writes into the *default* graph, so make ours the default.
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def)

    # Printing is a side effect, so use a plain loop instead of a throwaway
    # list comprehension. Operation names match the imported node names.
    for op in graph.get_operations():
        print(op.name)


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        sys.exit("usage: %s MODEL.pb" % sys.argv[0])
    print_graph_nodes(sys.argv[1])

猜你喜欢

转载自 https://blog.csdn.net/intjun/article/details/82215738