tensorflow保存的模型

1. tensorflow保存了3个文件

model.ckpt-10000.data-00000-of-00001                         
model.ckpt-10000.index                                    
model.ckpt-10000.meta 
  • 一般调用生成的模型,直接model.ckpt-1000这样的格式即可
  • data中存储的是模型的变量值
  • index 存储的是tensor名称
  • meta 存储的是graph结构,包括 GraphDef, SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值

2. 计算模型中的参数量

  • keras是可以直接输出每层的结构,并且在最后自动计算参数量
  • 普通的tensorflow可以调用训练生成的模型,计算参数量
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np

checkpoint_path = os.path.join("models_pretrained/", "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
    # print(key)
    # print(reader.get_tensor(key))
    shape = np.shape(reader.get_tensor(key))  #get the shape of the tensor in the model
    shape = list(shape)
    # print(shape)
    # print(len(shape))
    variable_parameters = 1
    for dim in shape:
        # print(dim)
        variable_parameters *= dim
    # print(variable_parameters)
    total_parameters += variable_parameters

print(total_parameters)
  1. 计算模型的浮点运算量
    指导方法
    但是还没有成功跑通,暂留!

  2. 日志输出

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='1' # 这是默认的显示等级,显示所有信息
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error 
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3' # 只显示 Error  
发布了98 篇原创文章 · 获赞 9 · 访问量 4万+

猜你喜欢

转载自blog.csdn.net/qq_40168949/article/details/103510593
今日推荐