TensorFlow中查看checkpoint文件中的变量名和对应值

在加载模型时, 需要知道checkpoint中变量名称,一下代码可以查看TensorFlow中checkpoint文件中的变量名:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
from tensorflow.python import pywrap_tensorflow
model_dir = "Saved_model"
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key, end=' ')
    print(reader.get_tensor(key))

代码说明:

在上述代码执行文件中,有一个保存的路径为“Saved_model/model.ckpt”的模型,具体信息如下图suo

你可以根据你自己的路径来修改其中的参数即可查看你自己的checkpoint文件中的变量名和变量值。 

猜你喜欢

转载自blog.csdn.net/Sophia_11/article/details/84931544