三、TensorFlow模型的保存和加载

1、模型的保存:

import tensorflow as tf
v1 = tf.Variable(1.0,dtype=tf.float32)
v2 = tf.Variable(2.0,dtype=tf.float32)

x = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(x)
    #将模型保存在model文件夹下
    saver.save(sess,'./model/test.model')

    print('result:{}'.format(result))

2、模型的加载(直接加载图)

import tensorflow as tf
saver = tf.train.import_meta_graph('./model/test.model.meta')
with tf.Session() as sess:
    saver.restore(sess,'./model/test.model')
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

3、模型的加载(给定映射关系,主要用于不同开发之间模型的调用)

import tensorflow as tf
a = tf.Variable(5.0,dtype=tf.float32,name='a')
b = tf.Variable(6.0,dtype=tf.float32,name='b')

x = a + b

saver = tf.train.Saver({'v1':a,'v2':b})
with tf.Session() as sess:
    saver.restore(sess,'./model/test.model')
    print(sess.run([x]))

 

猜你喜欢

转载自www.cnblogs.com/allen-GC/p/10720423.html
0条评论
添加一条新回复