TensorFlow网络模型的保存与载入

版权声明:本文为博主原创,转载必须标明出处: https://blog.csdn.net/botao_li/article/details/85112081

保存

tf.train.Saver类保存模型

import tensorflow as tf


a = tf.Variable(2, dtype=tf.float32)
print(a)  # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
b = tf.multiply(a, 3)
print(b)  # Tensor("Mul:0", shape=(), dtype=float32)
c = tf.add(a, b)
print(c)  #  Tensor("Add:0", shape=(), dtype=float32)

# 建立saver
saver = tf.train.Saver()

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())

    # 循环更新变量数值
    for i in range(50):
        print(sess.run(c))
        sess.run(tf.assign_add(a, 1.5))

        # 保存模型
        # 下方代码中loop.ckpt为保存模型文件的名称
        # 模型名称中的.ckpt无特别意义,常用于表明保存checkpoint
        # global_step参数用于在保存文件名中加入迭代次数,默认保存最近4次迭代的模型
        saver.save(sess, "d:/temp/untitled/loop.ckpt", global_step=i)
        """
        在d:/temp/untitled/文件夹下生成13个文件
        checkpoint
        loop.ckpt-46.meta
        loop.ckpt-46.index
        loop.ckpt-46.data-00000-of-00001
        loop.ckpt-47.meta
        loop.ckpt-47.index
        loop.ckpt-47.data-00000-of-00001
        loop.ckpt-48.meta
        loop.ckpt-48.index
        loop.ckpt-48.data-00000-of-00001
        loop.ckpt-49.meta
        loop.ckpt-49.index
        loop.ckpt-49.data-00000-of-00001
        """

    # 保存模型
    # 下方代码中model为保存模型文件的名称
    saver.save(sess, "d:/temp/model")
    """
    在d:/temp/文件夹下生成4个文件
    checkpoint
    model.meta
    model.index
    model.data-00000-of-00001
    """

载入

仅载入模型数据

import tensorflow as tf

# 建立与保存时一致的模型
# 变量初始值随便,因为在下方会从模型保存文件中导入
a = tf.Variable(0.237, dtype=tf.float32)
b = tf.multiply(a, 3)
c = tf.add(a, b)

# 建立saver
saver = tf.train.Saver()

with tf.Session() as sess:
    # 载入模型数据
    # 注意路径名称与save函数一致,不是填写文件名
    saver.restore(sess, "d:/temp/model")

    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))

仅载入计算图

import tensorflow as tf

# 载入模型图,不包含数据
# 注意路径指向为save时保存生成的.meta文件
saver = tf.train.import_meta_graph("d:/temp/model.meta")

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())

    # 取得计算图
    graph = tf.get_default_graph()

    # 取得tensor
    a_tensor = graph.get_tensor_by_name("Variable:0")
    b_tensor = graph.get_tensor_by_name("Mul:0")
    c_tensor = graph.get_tensor_by_name("Add:0")

    print(sess.run(a_tensor))  # 变量的初始值2
    print(sess.run(b_tensor))  # 首轮计算的数值6
    print(sess.run(c_tensor))  # 首轮计算的数值8

载入计算图和数据

import tensorflow as tf

# 载入模型图,不包含数据
# 注意路径指向为save时保存生成的.meta文件
saver = tf.train.import_meta_graph("d:/temp/model.meta")

with tf.Session() as sess:
    # 载入模型数据
    # 注意下方数据来源于同一模型,但是并非在一个save函数中保存,结果正确
    # 最好用同一个save生成的模型数据"d:/temp/model"
    saver.restore(sess, "d:/temp/untitled/loop.ckpt-46")

    # 取得计算图
    graph = tf.get_default_graph()

    # 取得tensor
    a_tensor = graph.get_tensor_by_name("Variable:0")
    b_tensor = graph.get_tensor_by_name("Mul:0")
    c_tensor = graph.get_tensor_by_name("Add:0")

    print(sess.run(a_tensor))
    print(sess.run(b_tensor))
    print(sess.run(c_tensor))

固化

在模型完成训练后,用于实际计算应用场景下时,可以将变量转化为常量,成为计算图的一部分。一方面方便保存与载入,另一方面提高运算效率。

固化保存代码如下:

import tensorflow as tf

x = tf.placeholder(dtype=tf.float32, shape=[])  # 定义输入
print(x)  # Tensor("Placeholder:0", shape=(), dtype=float32)
a = tf.Variable(2, dtype=tf.float32)
print(a)  # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
b = tf.multiply(a, 3)
print(b)  # Tensor("Mul:0", shape=(), dtype=float32)
c = tf.add(a, b)
print(c)  #  Tensor("Add:0", shape=(), dtype=float32)
y = tf.subtract(x, c)
print(y)  # Tensor("Sub:0", shape=(), dtype=float32)

with tf.Session() as sess:
    # 变量初始化
    sess.run(tf.global_variables_initializer())

    # 循环更新变量数值,模拟训练的过程
    for i in range(50):
        print(sess.run(c))
        sess.run(tf.assign_add(a, 1.5))

    # 生成当前计算图的GraphDef
    graph_def = tf.get_default_graph().as_graph_def()

    # 将当前图中变量全部转为常量
    # 注意,output_node_names传入输出节点名称列表,注意节点名称与tensor名称的区别
    graph_def_output = tf.graph_util.convert_variables_to_constants(sess=sess, input_graph_def=graph_def, output_node_names=["Placeholder", "Sub", "Variable"])

    # 下方代码生成保存文件constant_model,大多数情况下会加上.pb扩展名表明类型
    with tf.gfile.FastGFile("d:/temp/constant_model", "wb") as f:
        f.write(graph_def_output.SerializeToString())

固化载入代码如下:

import tensorflow as tf

with tf.Session() as sess:
    with tf.gfile.FastGFile("d:/temp/constant_model", "rb") as f:
        # 建立GraphDef,用于导入计算图
        graph_def = tf.GraphDef()

        # 从文件内容解析GraphDef内容
        graph_def.ParseFromString(f.read())

        # 从GraphDef加载计算图
        [x_tensor, y_tensor, a_tensor] = tf.import_graph_def(graph_def, return_elements=["Placeholder:0", "Sub:0", "Variable:0"])

        print(sess.run(a_tensor))
        print(sess.run(y_tensor, feed_dict={x_tensor: 1000}))


注意

关于CNN实际计算时去除训练中的batch维度

在载入CNN网络模型时,如果用于实际计算,而不是用于训练,则取出计算输出的tensor,在run()函数中feed_dict参数输入的placeholder,batch维度的数值设置为1,即输入tensor的shape为(1, 行数目, 列数目, 通道数目)

因为计算输出的tensor的计算中不包含需要从batch维度取值的操作。

猜你喜欢

转载自blog.csdn.net/botao_li/article/details/85112081