TensorFlow模型持久化

保存模型

import tensorflow as tf

# Build a minimal graph: two one-element variables and their sum.
v1 = tf.Variable(tf.constant(1., shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2., shape=[1]), name='v2')
result = v1 + v2

# Created after the variables so that it can see them.
saver = tf.train.Saver()

with tf.Session() as sess:
    # Variables must hold a value before they can be written out.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # According to the original notes, saving produces three files:
    #   model.ckpt.meta - the structure of the computation graph
    #   model.ckpt      - the values of the graph's variables
    #   checkpoint      - the list of model files in this directory
    # NOTE(review): newer TensorFlow releases may split the variable data
    # into model.ckpt.index / model.ckpt.data-* files; confirm against the
    # installed version.
    saver.save(sess, './model.ckpt')

    # To inspect the graph structure in a readable text form, the original
    # suggested (left disabled):
    #   saver.export_meta_graph('model.ckpt.meta.json', as_text=True)

加载模型

import tensorflow as tf

# Re-declare the same graph that was saved; the variable names ('v1', 'v2')
# are what tie these variables to the entries stored in the checkpoint.
v1 = tf.Variable(tf.constant(1., shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2., shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # No initializer is run here: the variables get their values from the
    # checkpoint instead.
    saver.restore(sess, 'model.ckpt')
    restored_sum = sess.run(result)
    print(restored_sum)

若不想先定义好图再加载,可以直接导入已保存的计算图:

import tensorflow as tf
# Load the previously saved graph into the default graph instead of
# re-declaring the variables by hand.
saver = tf.train.import_meta_graph('model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    # Tensors are looked up by name because no Python handles to them
    # exist in this script.
    graph = tf.get_default_graph()
    sum_tensor = graph.get_tensor_by_name('add:0')
    print(sess.run(sum_tensor))

加载部分变量:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1., shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2., shape=[1]), name='v2')
result = v1 + v2

# Restrict the Saver to a subset of the variables: only v2 is handled.
vars_to_restore = [v2]
saver = tf.train.Saver(vars_to_restore)

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    # Deliberate demonstration: per the original note this line raises an
    # error. v1 was not given to the Saver and no initializer is run in this
    # snippet.
    print(sess.run(v1))

若变量名字不一样,可以:

import tensorflow as tf

# The variables now carry different graph names than the ones used when the
# checkpoint was written.
v1 = tf.Variable(tf.constant(1., shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2., shape=[1]), name='other-v2')
result = v1 + v2

# Map the name stored in the checkpoint ('v1') onto the variable that should
# receive its value.
checkpoint_name_to_variable = {'v1': v1}
saver = tf.train.Saver(checkpoint_name_to_variable)

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    print(sess.run(v1))
    # Deliberate demonstration: per the original note this line raises an
    # error. v2 is not in the Saver's mapping and no initializer is run in
    # this snippet.
    print(sess.run(v2))

加载滑动平均

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
# apply() expects a list of variables; a bare tf.Variable is not iterable,
# so the variable is wrapped in a list.
ema_op = ema.apply([v])

# Mapping from checkpoint names to variables. Per the original note this
# prints something like {'v/ExponentialMovingAverage': v}.
restore_map = ema.variables_to_restore()
print(restore_map)
saver = tf.train.Saver(restore_map)

with tf.Session() as sess:
    # NOTE: the original kept a commented-out
    #   saver.save(sess, 'model.ckpt')
    # here, presumably to create a suitable checkpoint before restoring.
    saver.restore(sess, 'model.ckpt')

将图和变量同时保存

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Definition of the current default graph.
    graph_def = tf.get_default_graph().as_graph_def()

    # Merge the graph and its variables: variables are converted to
    # constants. 'add' names the node to keep as output (a node name, so no
    # ':0' suffix).
    output_node_names = ['add']
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)

    # Write the merged graph to a single file in serialized form.
    with tf.gfile.GFile('model.pb', 'wb') as f:
        serialized_graph = output_graph_def.SerializeToString()
        f.write(serialized_graph)

从文件中恢复

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # Read the serialized graph definition back from disk. tf.gfile.GFile
    # replaces the deprecated FastGFile and matches the call used when
    # model.pb was written.
    with tf.gfile.GFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the graph (with its values baked in) and fetch the requested
    # tensor; return_elements is a list, so `result` is a list as well.
    result = tf.import_graph_def(graph_def, return_elements=['add:0'])
    print(sess.run(result))

获取模型中变量的值

import tensorflow as tf

# Open the checkpoint directly; no graph or session is needed to read it.
reader = tf.train.NewCheckpointReader('./model.ckpt')
global_variables = reader.get_variable_to_shape_map()

for var_name, var_shape in global_variables.items():
    # Variable name
    print(var_name)
    # Variable shape
    print(var_shape)
    # Variable value
    print(reader.get_tensor(var_name))

定时保存模型

# If global_step is 1000, the model is saved as model.ckpt-1000.
# NOTE(review): saver, sess and global_step are not defined in this snippet;
# they are expected to come from the surrounding training code.
saver.save(sess,'model.ckpt',global_step=global_step)

从多个模型中找到最新的模型

# Look up the checkpoint state recorded for the directory `path`. The result
# is a state object (not just a file name) listing the newest checkpoint and
# all known checkpoints, or None when no checkpoint state is available.
# NOTE(review): `path` is not defined in this snippet; it must be supplied by
# the surrounding code.
ckpt = tf.train.get_checkpoint_state(path)
# Example output from the original notes:
# model_checkpoint_path: "./model/model.ckpt-29001"
# all_model_checkpoint_paths: "./model/model.ckpt-25001"
# all_model_checkpoint_paths: "./model/model.ckpt-26001"
# all_model_checkpoint_paths: "./model/model.ckpt-27001"
# all_model_checkpoint_paths: "./model/model.ckpt-28001"
# all_model_checkpoint_paths: "./model/model.ckpt-29001"
print(ckpt)
# Guard against a missing checkpoint: attribute access on None would raise
# AttributeError.
if ckpt and ckpt.model_checkpoint_path:
    # Path of the newest checkpoint, e.g. ./model/model.ckpt-29001
    print(ckpt.model_checkpoint_path)
else:
    print('No checkpoint found in', path)

猜你喜欢

转载自blog.csdn.net/a13602955218/article/details/80710788