Tensorflow模型保存与提取

    1、tf.train.Saver类

    tensorflow主要通过train.Saver类来保存和提取模型,该类定义在tensorflow/python/training/saver.py中

    Saver的初始化参数如下:

__init__(self,
    var_list=None, #一个字典,指定保存的对象列表,默认为None,即保存所有可保存对象
    reshape=False, #当为True时,表示从一个checkpoint中恢复参数时允许参数shape发生变化
    sharded=False, #是否将变量轮循放到所有设备上
    max_to_keep=5, #保存模型时会滚动更新,该值指定保存的模型个数
    keep_checkpoint_every_n_hours=10000.0, #按时间间隔来保存模型
    name=None,
    restore_sequentially=False, #是否按顺序恢复变量
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

    2、保存模型与参数

    过程为:定义计算图、执行初始化和计算、保存计算后的计算图和参数

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
 
saver = tf.train.Saver()
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "model/model.ckpt")

    其中,save的可选参数如下:

save(
    sess, #必须是加载了计算图、且变量已经初始化的session
    save_path, #模型的路径
    global_step=None, #如果提供,会添加在save_path后面,以区分不同阶段的模型
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

    执行后,会在model文件夹下得到四个文件:

    

    其中,checkpoint保存模型的列表,model.ckpt.meta文件保存了计算图的结构信息,model.cpkt.index和model.cpkt.data保存的是参数名和参数值。(旧版的tf保存的数据只有一个cpkt文件,而新版的tf把它分成了两个文件)

    3、模型的读取与恢复

    根据checkpoint文件来寻找最新的模型:

ckpt = tf.train.get_checkpoint_state('./model/') #锁定最新模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') #model_checkpoint_path: ./model/model.ckpt-4
with tf.Session() as sess:
    saver.restore(sess,ckpt.model_checkpoint_path)  

    恢复模型的过程为:定义计算图、恢复参数、执行计算,相当于用restore取代了初始化    

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
 
saver = tf.train.Saver()
 
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt")
    print(sess.run(result)) 

    也可以直接恢复checkpoint中的计算图,图中的变量通过变量名来获得

saver = tf.train.import_meta_graph("model/model.ckpt.meta")
 
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt") 
    result = tf.get_default_graph().get_tensor_by_name("add:0")
    print(sess.run(result))

    在恢复时,默认把变量值指定给同名的变量,若计算图中的变量在checkpoint中不存在,则会报NotFoundError。(checkpoint中的数据能多不能少)

    也可以手动把checkpoint中的变量值指定给计算图中不同名的变量:

u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2
saver = tf.train.Saver({"v1": u1, "v2": u2}) #"v1"、"v2"是ckpt中的变量名,u1、u2是当前环境中的变量
 
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt")
    print(sess.run(result))

    其中,可以通过以下方法来读取checkpoint中的变量名和变量值:

ckpt = tf.train.get_checkpoint_state('./model')
checkpoint_path = ckpt.model_checkpoint_path 

# Read data from checkpoint file  
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)  
var_to_shape_map = reader.get_variable_to_shape_map() 

# Print tensor name and values  
for key in var_to_shape_map:  
     print("tensor_name: ", key)   
     print("tensor_value: ", reader.get_tensor(key))

猜你喜欢

转载自blog.csdn.net/xiezongsheng1990/article/details/81011115