Tensorflow模型通过ckpt获取参数

ckpt是Tensorflow保存的已知模型,获取参数名:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
 
model_dir = "./train_results/checkpoint/"
 
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path
 
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()
 
for key, val in param_dict.items():
    try:
        print(key, val)
    except:
        pass

model_dir:是存放ckpt模型的路径

发布了128 篇原创文章 · 获赞 132 · 访问量 17万+

猜你喜欢

转载自blog.csdn.net/yql_617540298/article/details/89945789