查看 tensorflow 的模型保存的检查点 checkpoint 文件

  1. 模型保存
    官网说明:https://www.tensorflow.org/guide/saved_model?hl=zh-cN
>>> import tensorflow as tf
>>> v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
>>> v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
>>> inc_v1 = v1.asign(v1+1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'RefVariable' object has no attribute 'asign'
>>> inc_v1 = v1.assign(v1+1)
>>> inc_v2 = v2.assign(v2-1)
>>> init_op = tf.global_variables_initializer()
>>> saver = tf.train.Saver()
>>> with tf.Session() as sess:
...     sess.run(init_op)
...     inc_v1.op.run()
...     inc_v2.op.run()
...     save_path = saver.save(sess,'/Users/jiweiwang/temp/model.ckpt')
...     print("Model saved in path: %s" % save_path)
  1. 查看检查点 checkpoint 文件
    print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, all_tensor_names=False, count_exclude_pattern="")可以查看 TensorFlow 保存模型的参数名称、模型中的参数值、参数总量
    官方说明:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py
    参数:
    file_name:检查点文件的名称.
    tensor_name:指定要打印检查点文件中特定张量的名称,不打印指定名称,则设为 None,需要注意如果只想查看指定张量,那么位置参数 all_tensors 与关键词参数 all_tensor_name 都需要设置 False,否则还会把所有的张量都打印出来.
    all_tensors:布尔型参数,True 表示打印检查点文件中保存的所有张量.
    all_tensor_names: 布尔型关键字参数,默认为 False,表示是否打印所有变量的名称,如果只想把所有的张量名称打印出来,而不打印具体的张量值,需要设置位置参数 all_tensors 为 False.
    count_exclude_pattern: 正则化字符串,用于计数是排除匹配指定的张量.
    (1)打印指定的张量名称及其数值
>>> from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
>>> print_tensors_in_checkpoint_file('/Users/jiweiwang/temp/model.ckpt', None, True) # 需要注意:参数 file_name 传入的模型不带 .meta、.index、.data-00000-of-00001 等拓展名
tensor_name:  v1
[1. 1. 1.]
tensor_name:  v2
[-1. -1. -1. -1. -1.]
# Total number of params: 8

猜你喜欢

转载自blog.csdn.net/sdnuwjw/article/details/111321954
今日推荐