Tensorflow(2)保存模型与恢复

###一、数据模型的保存
使用saver类,自动保存tensorflow的图结构(***.ckpt.meta),参数取值(***.ckpt.data),以及目录下的文件列表(***.ckpt.index),还有一个checkpoint文件。

  • 定义变量
  • 变量操作
  • 变量初始化
  • 构建saver类
  • 使用保存模型参数到文件
import tensorflow as tf

v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(4.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(4.0,shape=[1]),name='v3')

result1=v1+v2
result2=result1+v3

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
	sess.run(init_op)
	saver.save(sess,"codes/tensorflow_test/model/model.ckpt")

###二、参数恢复
参数恢复程序中必须已经定义了参数,并且参数名称要和定义的参数名字一致。可以使用tf.Variables(),tf.global_variables()获取参数名字(适用于修改别人程序的时候)。
note:参数名称必须一致(name=“ ”),具体的操作(result,result2)可以修改

for variables in tf.global_variables():
	print(variables.name,variables.shape)

程序

import tensorflow as tf

v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(1.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(1.0,shape=[1]),name='v3')

result1=v1+v2
result2=v1+v2+v3

saver=tf.train.Saver()
#saver=tf.train.Saver([v1,v2] )#部分参数恢复,这个时候需要注释v3,或者给v3加上初始化操作。
with tf.Session() as sess:
	saver.restore(sess,"codes/tensorflow_test/model/model.ckpt")
	# for variables in tf.global_variables():
	# 	print(variables.name)
	print(sess.run(result1))
	print(sess.run(result2))
	print("sucessful\n")

###三、计算图与参数同时恢复

  • 计算图和参数同时恢复的时候,不需要定义变量,也不需要变量初始化
  • 变量在导入图结构之后就已经获取了
  • 可以使用原来的图结构中的操作,这个时候需要指定运算名称
  • 也可以自定义新的操作
import tensorflow as tf

saver=tf.train.import_meta_graph("/home/wuwei/codes/tensorflow_test/model/model.ckpt.meta")
# for variables in tf.global_variables():  ###get name and shape
#     print(variables.name, variables.shape)
result3=tf.get_default_graph().get_tensor_by_name("v1:0")+tf.get_default_graph().get_tensor_by_name("v2:0")
with tf.Session() as sess:
    saver.restore(sess,"/home/wuwei/codes/tensorflow_test/model/model.ckpt")
    print(sess.run(result3))  ### our op
    print("**********************")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add_1:0")))###original op
    print("sucessful\n")

###四、另一种保存tensorflow模型的操作,整个图和参数设置为常量,保存在一个文件中
在使用tensorRT进行推理的时候,需要使用到这种模型。

##保存模型save_model.py

import tensorflow as tf
from tensorflow.python.framework import graph_util


v1=tf.Variable(tf.constant(5.0,shape=[1],name="v1"))
v2=tf.Variable(tf.constant(4.0,shape=[1],name="v2"))
v3=tf.Variable(tf.constant(3.0,shape=[1],name="v3"))

result=v1+v2
print(result.name)
init_op=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    graph_def=tf.get_default_graph().as_graph_def()

    output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    with tf.gfile.GFile("/home/wuwei/codes/tensorflow_test/model/combined_model.pb",'wb') as f:
        f.write(output_graph_def.SerializeToString()) 

##restore3.py

import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename="/home/wuwei/codes/tensorflow_test/model/combined_model.pb"
    with gfile.FastGFile(model_filename,'rb') as f:
        graph_def=tf.GraphDef()
        graph_def.ParseFromString(f.read())
    result=tf.import_graph_def(graph_def,return_elements=["add:0"])
    print(sess.run(result))

猜你喜欢

转载自blog.csdn.net/weixin_40100431/article/details/82860478
今日推荐