《TensorFlow:实战Google深度学习框架》——5.4 模型持久化(模型保存、模型加载)

目录

1、持久化代码实现

2、加载保存的TensorFlow模型

3、加载部分变量

4、加载变量时重命名


1、持久化代码实现

TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver类。一下代码给出了保存TensorFlow计算图的方法。

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # saver.sabe函数保存到“Saved_model/model.ckpt”
    saver.save(sess, "Saved_model/model.ckpt")

解析:

  • 在这段代码中,通过saver.save 函数将TensorFlow模型保存到了“Saved_model/model.ckpt”文件中。TensorFlow模型一般会存在后缀为.ckpt的文件中 。
  • 虽然以上程序只指定了 一个文件路径,但是在这个文件目录下会出现三个文件:
  1. 第一个文件为model.ckpt.meta,它保存了 TensorFlow计算图的结构
  2. 第二个文件为model.ckpt,这个文件中保存了TensorFlow 程序中每一个变量的取值。
  3. 第三个文件为checkpoint文件,这个文件中保存了一个目录下所有的模型文件列表

2、加载保存的TensorFlow模型

以下代码中给出了加载这个已经保存的TensorFlow模型的方法 

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# 保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

saver = tf.train.Saver()

# 加载保存的模型,加载全部模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(result))

上述代码输出为:

解析:

这段加载模型的代码基本上和保存模型的代码是一样的。在加载模型的程序中也是先定义了TensorFlow计算图上的所有运算,并声明了 一个tf.train.Saver类。两段代码唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过己经保存的模型加载进来。


如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。一下代码给出一个样例:

import tensorflow as tf

# 加载持久化的图
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")

with tf.Session() as sess:
    saver.restore(sess,"Saved_model/model.ckpt")
    # 通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

这段代码与上述代码达到的效果相同。是两种方式加载模型。


3、加载部分变量

为了保存或者加载部分变量,在声明 tf.train.Saver 类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train. Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来 。如果运行修改后只加载了v1的代码会得到变革未初始化的错误:

tensorflow.python.framework.errors.FailedPreconditionError:Attempting touse uninitialized value v2

4、加载变量时重命名

tf.train.Saver类也支持在保存或者加载时给变量重命名。下面给出了一个简单的样例程序说明变量重命名是如何被使用的。

#!/usr/bin/env python
# -*- coding:utf-8 -*-

import tensorflow as tf

# tf.reset_default_graph()

# 声明变量
V1 = tf.Variable(tf.constant(1.0, shape=[1]), name="a1")
V2 = tf.Variable(tf.constant(2.0, shape=[1]), name="a2")
# result = V1 + V2

# 这里要注意,checkpoint中的变量名的问题,不然就会出现问题
saver = tf.train.Saver({"Variable": V1, "Variable_1": V2})

# 加载保存的模型,加载全部模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print(sess.run(V1+V2))

上述关于查看checkpoint文件中的变量名的问题,请参考博文TensorFlow中查看checkpoint文件中的变量名和对应值

猜你喜欢

转载自blog.csdn.net/Sophia_11/article/details/84929304