跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

save =  tf.train.Saver()

通过save. save() 实现数据的加载

通过save.restore() 实现数据的导出

第一步: 数据的载入

import tensorflow as tf

#创建变量
v1 = tf.Variable(tf.random_normal([1, 2], name='v1'))
v2 = tf.Variable(tf.random_normal([2, 3], name='v2'))
#初始化变量
init_op = tf.global_variables_initializer()
#构建训练模型的保存
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    print('V1:', sess.run(v1))
    print('V2:', sess.run(v2))
    # saver.save(保存内容, 保存路径)
    saver_path = saver.save(sess, 'save/model.ckpt')
    print('Model saved in file:', saver_path)

第二步: 数据的导出

import tensorflow as tf
# v1,v2的设定,主要是看看输出的v1是哪个v1 v1
= tf.Variable(tf.random_normal([1, 2]), name='v1') v2 = tf.Variable(tf.random_normal([2, 3]), name='v2') # 构建保存模型 saver = tf.train.Saver() with tf.Session() as sess: # 重新加载模型(重新赋予名字, 加载的路径) saver.restore(sess, 'save/model.ckpt') print('V1:', sess.run(v1)) print('V2:', sess.run(v2)) print('Model restored')

猜你喜欢

转载自www.cnblogs.com/my-love-is-python/p/9570286.html