Tensorflow【实战Google深度学习框架】TensorFlow模型的保存与恢复加载

我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:

import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    op = tf.add(xy, b, name='op_to_store')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    path = os.path.dirname(os.path.abspath(ckpt_file_path))
    if os.path.isdir(path) is False:
        os.makedirs(path)

    tf.train.Saver().save(sess, ckpt_file_path)
    
    # test
    feed_dict = {x: 2, y: 3}
    print(sess.run(op, feed_dict))

程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

checkpoint 文本文件,记录了模型文件的路径信息列表
model.ckpt.data-00000-of-00001 网络权重信息
model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf
以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
    sess = tf.Session()
    saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')  # 加载模型结构
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))  # 只需要指定目录就可以恢复所有变量信息

    # 直接获取保存的变量
    print(sess.run('b:0'))

    # 获取placeholder变量
    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')
    # 获取需要进行计算的operator
    op = sess.graph.get_tensor_by_name('op_to_store:0')

    # 加入新的操作
    add_on_op = tf.multiply(op, 2)

    ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
    print(ret)

完整代码

github!!!!!!!!!!!!!

发布了686 篇原创文章 · 获赞 195 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/weixin_43838785/article/details/104511582
今日推荐