[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

TensorFlow模型训练的好网络参数如果想重复高效利用,模型参数保存与加载是必须掌握的模块。本文提供一种简单容易理解的方式来实现上述功能。参考博客地址
备注:
本文采用的是ckpt保存方式,在下篇博文中介绍更加常用的pb保存方式,包括ckpt文件如何转换的pb文件,和如何直接保存问pb文件,感兴趣可以去看看。

  • 模型保存

代码:

import tensorflow as tf

x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
w1 = tf.get_variable("w1",initializer=tf.truncated_normal([2, 1], stddev=0.1))
b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1])) 

y = tf.add(tf.matmul(x,w1),b1,name="out")
saver = tf.train.Saver()
with tf.Session() as sess:
    srun = sess.run
    srun(tf.global_variables_initializer())
    print("y: ",srun(y,{x:[[1,2]]}))
    #保存模型与参数
    saver_path = saver.save(sess, './Saver/test1/checkpoint_dir/MyModel')
    print("saver path: ",saver_path)

运行结果:

y:  [[0.26085645]]
saver path:  ./Saver/test1/checkpoint_dir/MyModel
  • 模型恢复

代码:

import tensorflow as tf

with tf.Session() as sess:    
    #加载运算图
    saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
    #加载参数
    saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
    graph = tf.get_default_graph()
    #导入输入接口
    x = graph.get_tensor_by_name("in:0")
    #导入输出接口
    y = graph.get_tensor_by_name("out:0")
    #进行预测
    print("y: ",sess.run(y,{x:[[1,2]]}))

运行结果:

y:  [[0.26085645]]
  • 结论

经过测试我们发现,当我们以相同的输入值去预测结果时,通过刚训练完成的网络与通过恢复的模型结果相同,验证了功能的正确性。

猜你喜欢

转载自blog.csdn.net/xiaosongshine/article/details/84723269