###一、数据模型的保存
使用saver类,自动保存tensorflow的图结构(***.ckpt.meta),参数取值(***.ckpt.data),以及目录下的文件列表(***.ckpt.index),还有一个checkpoint文件。
- 定义变量
- 变量操作
- 变量初始化
- 构建saver类
- 使用保存模型参数到文件
import tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(4.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(4.0,shape=[1]),name='v3')
result1=v1+v2
result2=result1+v3
init_op=tf.global_variables_initializer()
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess,"codes/tensorflow_test/model/model.ckpt")
###二、参数恢复
参数恢复程序中必须已经定义了参数,并且参数名称要和定义的参数名字一致。可以使用tf.Variables(),tf.global_variables()获取参数名字(适用于修改别人程序的时候)。
note:参数名称必须一致(name=“ ”),具体的操作(result,result2)可以修改
for variables in tf.global_variables():
print(variables.name,variables.shape)
程序
import tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(1.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(1.0,shape=[1]),name='v3')
result1=v1+v2
result2=v1+v2+v3
saver=tf.train.Saver()
#saver=tf.train.Saver([v1,v2] )#部分参数恢复,这个时候需要注释v3,或者给v3加上初始化操作。
with tf.Session() as sess:
saver.restore(sess,"codes/tensorflow_test/model/model.ckpt")
# for variables in tf.global_variables():
# print(variables.name)
print(sess.run(result1))
print(sess.run(result2))
print("sucessful\n")
###三、计算图与参数同时恢复
- 计算图和参数同时恢复的时候,不需要定义变量,也不需要变量初始化
- 变量在导入图结构之后就已经获取了
- 可以使用原来的图结构中的操作,这个时候需要指定运算名称
- 也可以自定义新的操作
import tensorflow as tf
saver=tf.train.import_meta_graph("/home/wuwei/codes/tensorflow_test/model/model.ckpt.meta")
# for variables in tf.global_variables(): ###get name and shape
# print(variables.name, variables.shape)
result3=tf.get_default_graph().get_tensor_by_name("v1:0")+tf.get_default_graph().get_tensor_by_name("v2:0")
with tf.Session() as sess:
saver.restore(sess,"/home/wuwei/codes/tensorflow_test/model/model.ckpt")
print(sess.run(result3)) ### our op
print("**********************")
print(sess.run(tf.get_default_graph().get_tensor_by_name("add_1:0")))###original op
print("sucessful\n")
###四、另一种保存tensorflow模型的操作,整个图和参数设置为常量,保存在一个文件中
在使用tensorRT进行推理的时候,需要使用到这种模型。
##保存模型save_model.py
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1=tf.Variable(tf.constant(5.0,shape=[1],name="v1"))
v2=tf.Variable(tf.constant(4.0,shape=[1],name="v2"))
v3=tf.Variable(tf.constant(3.0,shape=[1],name="v3"))
result=v1+v2
print(result.name)
init_op=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
graph_def=tf.get_default_graph().as_graph_def()
output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
with tf.gfile.GFile("/home/wuwei/codes/tensorflow_test/model/combined_model.pb",'wb') as f:
f.write(output_graph_def.SerializeToString())
##restore3.py
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename="/home/wuwei/codes/tensorflow_test/model/combined_model.pb"
with gfile.FastGFile(model_filename,'rb') as f:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
result=tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(result))