Inspecting node information in TensorFlow model files

The simplest way to see tensor names is to print the tensors directly in code, for example:

import tensorflow as tf  # TensorFlow 1.x API

x = tf.layers.Input(shape=[32])
print(x)
y = tf.layers.dense(x, 16, activation=tf.nn.softmax)
print(y)

Output:
Tensor("input_layer_1:0", shape=(?, 32), dtype=float32)
Tensor("dense/Softmax:0", shape=(?, 16), dtype=float32)
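A tensor name such as `dense/Softmax:0` has the form `<op name>:<output index>`: the part before the colon is the node (operation) name that appears in a GraphDef, and the number selects which output of that op is meant. This is also why `import_graph_def` is asked for `'out:0'` further down. The helper below, `split_tensor_name`, is a hypothetical stdlib-only sketch of that convention, assuming GraphDef input strings where a bare name means output 0 and a leading `^` marks a control dependency:

```python
from typing import Tuple


def split_tensor_name(name: str) -> Tuple[str, int]:
    """Split 'op_name:index' into (op_name, index).

    Hypothetical helper: a bare op name without ':' refers to output 0,
    and a leading '^' (how a GraphDef marks a control-dependency input)
    is stripped.
    """
    name = name.lstrip("^")
    op_name, sep, index = name.rpartition(":")
    if not sep:  # no colon: the whole string is the op name
        return name, 0
    return op_name, int(index)


print(split_tensor_name("input_layer_1:0"))  # ('input_layer_1', 0)
print(split_tensor_name("dense/Softmax:0"))  # ('dense/Softmax', 0)
```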


1. Inspect checkpoint node information in code:

from tensorflow.python import pywrap_tensorflow
import os

# Path prefix of the checkpoint (for V2 checkpoints, without the .index / .data-* suffix)
checkpoint_path = os.path.join("checkpoint-00454721")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # Uncomment to also print the stored values:
    # print(reader.get_tensor(key))
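Besides names, `get_variable_to_shape_map()` gives each variable's shape as a list of ints, which is enough to list shapes and count parameters. A minimal sketch, with a hypothetical helper `summarize_shapes` and a hand-written dictionary (made-up variable names) standing in for the reader's result so it runs without TensorFlow. Note that a training checkpoint usually also stores optimizer slot variables (for example Adam's moment estimates), which such a total would include:

```python
from math import prod  # Python 3.8+
from typing import Dict, List


def summarize_shapes(var_to_shape_map: Dict[str, List[int]]) -> int:
    """Print each variable with its shape and return the total parameter count."""
    total = 0
    for name in sorted(var_to_shape_map):  # sorted for a stable listing
        shape = var_to_shape_map[name]
        count = prod(shape)  # a scalar has shape [] and prod([]) == 1
        total += count
        print(f"{name}: shape={shape}, params={count}")
    print(f"total params: {total}")
    return total


# Hypothetical shape map standing in for reader.get_variable_to_shape_map()
example = {
    "dense/kernel": [32, 16],
    "dense/bias": [16],
    "global_step": [],
}
summarize_shapes(example)  # prints three lines, then "total params: 529"
```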


2. Inspect checkpoint node information with the TensorFlow command-line tool:

python -m tensorflow.python.tools.inspect_checkpoint --file_name=checkpoint-00454721


The following is reposted from: https://www.cnblogs.com/bonelee/p/8462578.html

Inspect node information in a TensorFlow .pb model file:

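import tensorflow as tf
with tf.Session() as sess:
    with open('./quantized_model.pb', 'rb') as f:
        graph_def = tf.GraphDef()            # tf.compat.v1.GraphDef in TF 2.x
        graph_def.ParseFromString(f.read())  # the .pb file is a serialized GraphDef
        print(graph_def)                     # prints every node in text format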

Result:

# ...
node {
  name: "FullyConnected/BiasAdd"
  op: "BiasAdd"
  input: "FullyConnected/MatMul"
  input: "FullyConnected/b/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
}
node {
  name: "FullyConnected/Softmax"
  op: "Softmax"
  input: "FullyConnected/BiasAdd"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
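When the dump is long, it helps to reduce it to a list of nodes and look for likely entry and exit points: `Placeholder` ops are usually the inputs, and nodes that no other node consumes are candidate outputs. In TensorFlow each entry of `graph_def.node` exposes `name`, `op` and `input`; the sketch below (hypothetical helper `find_io_nodes`) works on plain `(name, op, inputs)` tuples so it runs without TensorFlow. Only the last two nodes come from the dump above; the others are an assumed reconstruction of a typical fully connected layer:

```python
from typing import List, Sequence, Tuple

Node = Tuple[str, str, Sequence[str]]  # (name, op, input names)


def find_io_nodes(nodes: Sequence[Node]) -> Tuple[List[str], List[str]]:
    """Return (placeholder names, names of nodes that no other node consumes)."""
    consumed = set()
    for _, _, inputs in nodes:
        for inp in inputs:
            # 'x:1' refers to output 1 of node 'x'; '^x' is a control input on 'x'.
            consumed.add(inp.lstrip("^").split(":")[0])
    placeholders = [name for name, op, _ in nodes if op == "Placeholder"]
    sinks = [name for name, _, _ in nodes if name not in consumed]
    return placeholders, sinks


# Hypothetical node list; only BiasAdd and Softmax appear in the dump above.
nodes = [
    ("input", "Placeholder", []),
    ("FullyConnected/W", "Const", []),
    ("FullyConnected/W/read", "Identity", ["FullyConnected/W"]),
    ("FullyConnected/b", "Const", []),
    ("FullyConnected/b/read", "Identity", ["FullyConnected/b"]),
    ("FullyConnected/MatMul", "MatMul", ["input", "FullyConnected/W/read"]),
    ("FullyConnected/BiasAdd", "BiasAdd",
     ["FullyConnected/MatMul", "FullyConnected/b/read"]),
    ("FullyConnected/Softmax", "Softmax", ["FullyConnected/BiasAdd"]),
]
print(find_io_nodes(nodes))  # (['input'], ['FullyConnected/Softmax'])
```

With a real model the list could be built as `[(n.name, n.op, list(n.input)) for n in graph_def.node]`. In a graph that has not been frozen, leftover nodes such as save or training ops also show up as sinks, so the result is a shortlist rather than a definitive answer.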

References: https://tang.su/2017/01/export-TensorFlow-network/

https://github.com/tensorflow/tensorflow/issues/15689

Some of the core code:

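import tensorflow as tf
with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        print(graph_def)
        # Import into the default graph. 'out:0' is a tensor name as it appears in graph.pb;
        # the imported nodes themselves are placed under the "import/" name scope by default.
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))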

This is part of my TensorFlow frozen graph; I have named the input and output nodes.

>>> g = tf.GraphDef()
>>> g.ParseFromString(open('frozen_graph.pb','rb').read())
>>> g
node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 68
        }
      }
    }
  }
}
...
node {
  name: "output"
  op: "Softmax"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
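In the `input` node above, `size: -1` marks a dimension left unspecified (typically the batch size) and `size: 68` is the feature count; the Android code below feeds it with dims `{1,68}` and a flat `float[]`. The number of values in such a flat buffer has to equal the product of the concrete dims. A hypothetical stdlib-only check (`check_flat_buffer` is an illustrative name), assuming -1 is only used for the batch dimension:

```python
from math import prod
from typing import Sequence, Tuple


def concrete_shape(graph_shape: Sequence[int], batch: int) -> Tuple[int, ...]:
    """Replace unknown (-1) dims with the batch size (assumes -1 only marks the batch dim)."""
    return tuple(batch if d == -1 else d for d in graph_shape)


def check_flat_buffer(buffer: Sequence[float], graph_shape: Sequence[int],
                      batch: int) -> Tuple[int, ...]:
    """Return the concrete dims, raising ValueError if the flat buffer has the wrong length."""
    dims = concrete_shape(graph_shape, batch)
    expected = prod(dims)
    if len(buffer) != expected:
        raise ValueError(f"buffer has {len(buffer)} values, dims {dims} need {expected}")
    return dims


print(check_flat_buffer([0.0] * 68, [-1, 68], batch=1))  # (1, 68)
```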

I ran this model with the following code (CELL is the name of the directory where my file is located):

final String MODEL_FILE = "file:///android_asset/" + CELL + "/optimized_graph.pb" ;
final String INPUT_NODE = "input" ;
final String OUTPUT_NODE = "output" ;
final int[] INPUT_SIZE = {1,68} ;
float[] RESULT = new float[8];

inferenceInterface = new TensorFlowInferenceInterface();
inferenceInterface.initializeTensorFlow(getAssets(),MODEL_FILE) ;
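inferenceInterface.fillNodeFloat(INPUT_NODE,INPUT_SIZE,input);  // input: float[68] holding one sample

and finally

inferenceInterface.runInference(new String[] {OUTPUT_NODE});  // run the graph up to the output node
inferenceInterface.readNodeFloat(OUTPUT_NODE,RESULT);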
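`RESULT` receives the 8 values of the `output` node, which is a `Softmax`, so they are non-negative and sum to about 1; a common next step is to take the index of the largest value as the predicted class. Sketched in Python for consistency with the other examples, with a hypothetical helper `predicted_class` and a made-up result vector:

```python
import math
from typing import Sequence


def predicted_class(probs: Sequence[float]) -> int:
    """Return the index of the largest softmax output (first one on ties)."""
    if not math.isclose(sum(probs), 1.0, abs_tol=1e-3):
        raise ValueError("values do not look like softmax probabilities")
    return max(range(len(probs)), key=lambda i: probs[i])


result = [0.01, 0.02, 0.05, 0.70, 0.10, 0.04, 0.03, 0.05]  # hypothetical RESULT
print(predicted_class(result))  # 3
```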

Reposted from blog.csdn.net/haima1998/article/details/80297710