TensorFlow 模型的保存与读取

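保存模型

注意：以下代码使用 TensorFlow 1.x 的图模式 API（tf.random_normal、tf.Session、tf.train.Saver）。在 TensorFlow 2.x 中需要改用 tf.compat.v1 下的对应接口，并先调用 tf.compat.v1.disable_eager_execution()。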

import tensorflow as tf

import os

# Build two variables with random initial values and write them to a checkpoint.
# NOTE: TensorFlow 1.x graph API (tf.random_normal / tf.Session / tf.train.Saver);
# under TensorFlow 2.x use the tf.compat.v1 equivalents with eager execution disabled.

# The `name` arguments are what the restore script must reuse: the checkpoint
# entries are looked up by these names when loading.
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
v2 = tf.Variable(tf.random_normal([2,3]), name="v2")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

ckpt_path = "save/model.ckpt"
# Saver.save refuses to write when the checkpoint's parent directory is missing
# ("Parent directory of ... doesn't exist, can't save."), so create it up front.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

with tf.Session() as sess:
    sess.run(init_op)
    print ("V1:",sess.run(v1))
    print ("V2:",sess.run(v2))
    # save() returns the path prefix it actually wrote to.
    saver_path = saver.save(sess, ckpt_path)
    print ("Model saved in file: ", saver_path)

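读取模型

读取前需要用与保存时相同的名称（name="v1"、name="v2"）和形状重新定义变量，因为 tf.train.Saver 默认按变量名在检查点中查找对应的值。saver.restore 会用检查点中的值覆盖变量，因此这里不需要再运行初始化操作。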

import tensorflow as tf
# Recreate the variables under the exact names/shapes used when saving, then
# load their values from the checkpoint instead of running an initializer.
# NOTE: TensorFlow 1.x graph API; under TF 2.x use tf.compat.v1 equivalents.
v1, v2 = (
    tf.Variable(tf.random_normal(shape), name=var_name)
    for var_name, shape in (("v1", [1, 2]), ("v2", [2, 3]))
)
saver = tf.train.Saver()

with tf.Session() as sess:
    # restore() assigns the checkpointed values, replacing the random init values.
    saver.restore(sess, "save/model.ckpt")
    for label, var in (("V1:", v1), ("V2:", v2)):
        print(label, sess.run(var))
    print("Model restored")

猜你喜欢

转载自blog.csdn.net/qq_41686130/article/details/95913455
今日推荐