保存模型
import tensorflow as tf
# Build a minimal graph: two one-element float variables and their sum.
v1_initial = tf.constant(1., shape=[1])
v1 = tf.Variable(v1_initial, name='v1')
v2_initial = tf.constant(2., shape=[1])
v2 = tf.Variable(v2_initial, name='v2')
# The sum is built with the `+` operator on purpose: other snippets in these
# notes look the resulting op up by the name 'add'.
result = v1 + v2

# A Saver constructed without arguments handles every variable created so far.
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # According to the original notes, saving produces three files:
    #   model.ckpt.meta - the structure of the computation graph
    #   model.ckpt      - the values of the graph's variables
    #   checkpoint      - a list of all model files in this directory
    # NOTE(review): depending on the TensorFlow version the variable values
    # may be written as model.ckpt.index / model.ckpt.data-* instead of a
    # single model.ckpt file - confirm for the version in use.
    saver.save(sess, './model.ckpt')
    # To inspect the graph structure, it can also be exported as text:
    # saver.export_meta_graph('model.ckpt.meta.json', as_text=True)
加载模型
import tensorflow as tf
# Re-create the same graph that was used when the checkpoint was written,
# so the variable names ('v1', 'v2') match the saved ones.
v1_initial = tf.constant(1., shape=[1])
v1 = tf.Variable(v1_initial, name='v1')
v2_initial = tf.constant(2., shape=[1])
v2 = tf.Variable(v2_initial, name='v2')
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # No initializer is run here: the variable values come from the checkpoint.
    saver.restore(sess, 'model.ckpt')
    restored_sum = sess.run(result)
    print(restored_sum)
若不想在加载前重复定义计算图,可以直接从 .meta 文件导入图结构:
import tensorflow as tf
# Import the previously saved graph structure into the default graph; the
# call returns a Saver that can restore the variables of that graph.
saver = tf.train.import_meta_graph('model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    # The graph was not built in this script, so fetch the tensor by name.
    default_graph = tf.get_default_graph()
    sum_tensor = default_graph.get_tensor_by_name('add:0')
    print(sess.run(sum_tensor))
加载部分变量:
import tensorflow as tf
# Same two-variable graph as in the saving example.
v1_initial = tf.constant(1., shape=[1])
v1 = tf.Variable(v1_initial, name='v1')
v2_initial = tf.constant(2., shape=[1])
v2 = tf.Variable(v2_initial, name='v2')
result = v1 + v2

# Restrict the Saver to a subset of the variables: only v2 will be restored.
variables_to_load = [v2]
saver = tf.train.Saver(variables_to_load)

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    # Intentional demonstration of a failure: v1 was neither restored from
    # the checkpoint nor initialized, so evaluating it raises an error.
    v1_value = sess.run(v1)
    print(v1_value)
若变量名字不一样,可以:
import tensorflow as tf
# The variables in this graph use names that differ from the saved ones.
v1_initial = tf.constant(1., shape=[1])
v1 = tf.Variable(v1_initial, name='other-v1')
v2_initial = tf.constant(2., shape=[1])
v2 = tf.Variable(v2_initial, name='other-v2')
result = v1 + v2

# Map the name stored in the checkpoint ('v1') to the variable that should
# receive its value. Only this one variable is handled by the Saver.
checkpoint_name_to_variable = {'v1': v1}
saver = tf.train.Saver(checkpoint_name_to_variable)

with tf.Session() as sess:
    saver.restore(sess, 'model.ckpt')
    v1_value = sess.run(v1)
    print(v1_value)
    # Intentional demonstration of a failure: v2 is not in the Saver's
    # mapping and was never initialized, so evaluating it raises an error.
    v2_value = sess.run(v2)
    print(v2_value)
加载滑动平均
import tensorflow as tf
# A single scalar float variable whose moving average is tracked.
v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
# apply() expects a list of variables; passing the bare variable fails
# because a TF 1.x Variable cannot be iterated.
ema_op = ema.apply([v])

# Mapping from the name used in the checkpoint to the variable to restore.
# According to the original notes this prints:
#   {'v/ExponentialMovingAverage': v}
restore_map = ema.variables_to_restore()
print(restore_map)

# With this mapping, the saved moving-average value is loaded into v itself.
saver = tf.train.Saver(restore_map)

with tf.Session() as sess:
    # A matching checkpoint must have been written beforehand, e.g. with:
    # saver.save(sess, 'model.ckpt')
    saver.restore(sess, 'model.ckpt')
将图和变量同时保存
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Build the same two-variable graph; the sum op is named 'add'.
v1_initial = tf.constant(1.0, shape=[1])
v1 = tf.Variable(v1_initial, name='v1')
v2_initial = tf.constant(2.0, shape=[1])
v2 = tf.Variable(v2_initial, name='v2')
result = v1 + v2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Definition (GraphDef) of the current default graph.
    graph_def = tf.get_default_graph().as_graph_def()

    # Merge graph and variables: the variables are replaced by constants
    # holding their current values. Only the nodes needed to compute the
    # listed output nodes are kept.
    output_node_names = ['add']
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)

    # Serialize the resulting GraphDef and write it to a single file.
    serialized_graph = output_graph_def.SerializeToString()
    with tf.gfile.GFile('model.pb', 'wb') as f:
        f.write(serialized_graph)
从pb文件中恢复图和变量
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    # Read the serialized graph definition from disk and parse it.
    # GFile is used instead of the deprecated FastGFile.
    with gfile.GFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph into the default graph and fetch the requested
    # output tensor(s); `result` is a list with one entry per element
    # named in return_elements.
    result = tf.import_graph_def(graph_def, return_elements=['add:0'])
    print(sess.run(result))
获取模型中变量的值
import tensorflow as tf
# Open the checkpoint for reading without building any graph.
reader = tf.train.NewCheckpointReader('./model.ckpt')

# Mapping from each saved variable's name to its shape.
global_variables = reader.get_variable_to_shape_map()

for var_name, var_shape in global_variables.items():
    # Variable name
    print(var_name)
    # Variable shape
    print(var_shape)
    # Variable value
    print(reader.get_tensor(var_name))
定时保存模型
# If global_step is 1000, the model is saved as model.ckpt-1000.
# NOTE(review): `saver`, `sess` and `global_step` are not defined in this
# fragment; they are expected to come from the surrounding training code.
saver.save(sess,'model.ckpt',global_step=global_step)
从多个模型中找到最新的模型
# Look up the checkpoint state recorded for the directory `path`.
# NOTE(review): `path` is not defined in this fragment; it is expected to be
# the directory that contains the saved checkpoints.
# The result is a state object (not a plain file name) that records the most
# recent checkpoint and all known checkpoints, as the sample output shows.
ckpt = tf.train.get_checkpoint_state(path)
if ckpt is None:
    # No checkpoint state is available for this directory; report it instead
    # of failing with an AttributeError below.
    print('No checkpoint found in: {}'.format(path))
else:
    # Sample output:
    # model_checkpoint_path: "./model/model.ckpt-29001"
    # all_model_checkpoint_paths: "./model/model.ckpt-25001"
    # all_model_checkpoint_paths: "./model/model.ckpt-26001"
    # all_model_checkpoint_paths: "./model/model.ckpt-27001"
    # all_model_checkpoint_paths: "./model/model.ckpt-28001"
    # all_model_checkpoint_paths: "./model/model.ckpt-29001"
    print(ckpt)
    # Sample output:
    # ./model/model.ckpt-29001
    print(ckpt.model_checkpoint_path)