Tensorflow 实战Google深度学习框架——学习笔记(五)TensorFlow持久化

TensorFlow模型持久化

模型持久化的目的:为了让训练完的模型可以在下次使用 TensorFlow提供了一个非常简单的API来保存和还原一个神经网络,这个API类就是tf.train.Saver类。以下是保存TensorFlow计算图的方法。

变量的持久化

1、保存变量(实际上也是保存计算图)

saver.save(sess, “../save_data/model.ckpt”)会把计算图和变量都保存下来

import tensorflow as tf
tf.reset_default_graph()  # 清空当前的变量,如果没有,每次保存数据都会是新的变量(不会重复保存在同一个名字中)
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(3.0, shape=[1]), name="v2")
result = v1 + v2

# 使用Saver()类生成一个对象用于保存
saver = tf.train.Saver()

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # 保存的是会话,实际上是模型,需要指定详细文件名
    saver.save(sess, "../save_data/model.ckpt")
    print(sess.run(result))
#     print(result.eval())
[ 4.]

保存后出现了四个文件,第一个文件是model.ckpt.meta,里面保存了TensorFlow计算图的结构;第二个为model.ckpt.data;保存了每一个变量的值。第三个为checkpoint文件,保存了一个目录下所有的模型文件列表。最后一个ckpt.index还不知道是什么,待补充

2、加载保存的TensorFlow变量

1.定义TensorFlow计算图上的所有变量
2.声明一个tf.train.Saver类
注意下面没有定义result,但是输出结果是有一个张量,尝试获取result的值会报错,因为只有变量没有运算,不能得出值。另外这个不保存在checkpoint中。

实际上保存的变量只有值,没有维度,所以需要再次声明变量

# 必须先清空
tf.reset_default_graph()
# 使用和保存模型中一样的方式来声明变量
# 这里的值定义什么都没有关系,会被保存的覆盖,把v2的值修改为2.0也会被保存的3.0覆盖
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "../save_data/model.ckpt")
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
#     print(result.eval())  # 没有定义运算也有
    print(result)
INFO:tensorflow:Restoring parameters from ../save_data/model.ckpt
v1: [ 1.]
v2: [ 3.]
Tensor("add:0", shape=(1,), dtype=float32)

注意:没有一个物理文件叫做model.ckpt。它是为检查点创建的文件名的前缀。用户只与前缀交互,而不与物理检查点文件交互。换句话说,你保存的文件没有model.ckpt这个文件,而是一个叫checkpoint的文件,如果要确定有没有保存该文件,只要检查checkpoint文件是否存在。`

3、单独存储和恢复某些变量

在不传入参数到Saver()中的时候,默认是存储或恢复所有变量。如果要指定保存/恢复某个变量,那就构造字典,键为名,值为值,如v2这个变量:{“v2”: v2}。可以多次传入参数到Saver()中,只有在运行restore()的时候才会全部执行,值发生改变。

下面演示单独加载一个变量

tf.reset_default_graph()
# v1值修改为5,和输出结果不一样说明并没有加载到,值加载了v2
v1 = tf.Variable(tf.constant(5.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 

saver = tf.train.Saver({"v2": v2})

with tf.Session() as sess:
    v1.initializer.run()  # 只初始化v1
    # # 保存的变量 v2 不用初始化,就像是做迁移学习直接用保存的值作为初始值,所以不要再初始化
    saver.restore(sess, "../save_data/model.ckpt")

    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
INFO:tensorflow:Restoring parameters from ../save_data/model.ckpt
v1 : [ 5.]
v2 : [ 3.]

4、在保存和加载的时候重命名变量

保存的时候要重命名,只需要在传入字典的时候{“v1”: u1},这样一来v1就会被保存为u1
加载的时候也一样,传入字典{“u1”: v1}这样就名字变回来了。

重命名的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({“v/ExponentialMovingAverage”: v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典

5、查看checkpoint中的变量

要查看检查点中的变量,可以使用inspect_checkpoint库,特别print_tensors_in_checkpoint_file 函数

from tensorflow.python.tools import inspect_checkpoint as ckpt

# 展示所有变量,第一个参数的文件名,第二个参数是张量名,所有参数的时候all_tensors参数为True
ckpt.print_tensors_in_checkpoint_file("../save_data/model.ckpt", tensor_name="", all_tensors=True)
print('----------------')

# 只展示v1
ckpt.print_tensors_in_checkpoint_file("../save_data/model.ckpt", tensor_name="v1", all_tensors=False)
print('----------------')

# 只展示v2
ckpt.print_tensors_in_checkpoint_file("../save_data/model.ckpt", tensor_name="v2", all_tensors=False)
print('----------------')
tensor_name:  v1
[ 1.]
tensor_name:  v2
[ 3.]
----------------
tensor_name:  v1
[ 1.]
----------------
tensor_name:  v2
[ 3.]
----------------

模型的持久化

使用SavedModel来保存和加载您的模型变量、图和图的元数据。这是一种与语言无关的、可恢复的、可恢复的序列化格式,可以支持更高级别的系统和工具来生成、使用和转换TensorFlow模型。TensorFlow提供了几种与SavedModel交互的方法,包括tf.saved_model api、tf.estimator.Estimator和命令行界面。

1、简单保存

创建SavedModel的最简单方法是使用tf.saved_model.simple_save 函数

with tf.Session() as sess:
    tf.saved_model.simple_save(sess, 
                              export_dir,  # 保存的路径
                               inputs={"x": x, "y": y},
                              outputs={"z": z})

2、将计算图中的模型保存于一个文件中

使用convert_variables_to_constants函数

import tensorflow as tf
from tensorflow.python.framework import graph_util


v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(3.0, shape=[1]), name="v2")
result = v1 + v2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分
    graph_def = tf.get_default_graph().as_graph_def()
    # 根据会话,计算过程定义一个输出对象,命名为add
    output_graph_def = graph_util.convert_variables_to_constants(sess,
                                                                graph_def,
                                                                ['add'])
    with tf.gfile.GFile("../datas/example.pb", "wb") as f:
        # 序列化为字符串后写入
        f.write(output_graph_def.SerializerToString())

3、载入包含变量及其取值的模型

import tensorflow as tf
from tensorflow.python.platform import gfile


with tf.Session() as sess:
    model_filename = "../datas/example.pb"
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))

猜你喜欢

转载自blog.csdn.net/m0_38106113/article/details/81545102
今日推荐