[TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)

[TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)

在上篇博文中,我们探索了TensorFlow模型参数保存与加载实现方法采用的是保存ckpt的方式。这篇博文我们会使用保存为pd格式文件来实现。
首先,我会在上篇博文基础上,实现由ckpt文件如何转换为pb文件,再去探索如何在训练时直接保存pb文件,最后是如何利用pb文件复现网络与参数完成应用预测功能。

  • ckpt文件转换pd文件

ckpt2pd文件代码:

import tensorflow as tf
pd_dir = "././Saver/test1/pb_dir/MyModel.pb"
with tf.Session() as sess:    
    #加载运算图
    saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
    #加载参数
    saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
    graph = tf.get_default_graph()
    out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
    saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
    print("saver path: ",saver_path)

运行结果:

saver path:  ././Saver/test1/pb_dir/MyModel.pb
  • 训练保存pd文件

train文件代码

import tensorflow as tf

pd_dir = "././Saver/test2/pb_dir/MyModel.pb"



def main():
    x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
    #x = tf.constant([[1,2]],dtype=tf.float32)
    w1 = tf.get_variable("w1",dtype=tf.float32,initializer=tf.truncated_normal([2, 1], stddev=0.1))
    b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1])) 

    y = tf.add(tf.matmul(x,w1),b1,name="out")
    
    with tf.Session() as sess:
        #获取计算图
        graph = tf.get_default_graph()
        #获取name和ops,这次代码并没有用到
        ret = graph.get_operations()
        r_names = []
        #获取name list
        for r in ret:
            r_names.append(r.name)

        srun = sess.run
        srun(tf.global_variables_initializer())
        print("y: ",srun(y,{x:[[1,2]]}))
        #存入输入与输出接口
        out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
        saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)

        
        print("saver path: ",saver_path)

if __name__ == "__main__":
    main()

运行结果:

y:  [[0.14729613]]
saver path:  ./././Saver/test2/pb_dir/MyModel.pb
  • pb文件复现网络与参数

restore文件代码

import tensorflow as tf
from saver1 import pd_dir

with tf.Session() as sess:
    #用上下文管理器打开pd文件    
    with open(pd_dir,"rb") as pd_flie:
        #获取图
        graph = tf.GraphDef()
        #获取参数
        graph.ParseFromString(pd_flie.read())
        #引入输入输出接口
        ins, outs = tf.import_graph_def(graph,return_elements=["in:0","out:0"])
        #进行预测
        print("y: ",sess.run(outs,{ins:[[1,2]]}))

运行结果:

y:  [[0.14729613]]

猜你喜欢

转载自blog.csdn.net/xiaosongshine/article/details/84756292