TensorFlow - Saving and Loading a Model (freeze method)

Copyright notice: this is an original article by the blogger; reposting is welcome. https://blog.csdn.net/samylee/article/details/85067155


Hardware: NVIDIA GTX 1080

Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0

I. Background

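Freezing: the three checkpoint (ckpt) files (.meta for the graph, .index and .data-* for the variable values) are merged into a single .pb file, and every variable is converted into a constant. Because only the nodes needed to compute the chosen outputs are kept, the file is smaller and easier to port.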
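
The code in this article freezes directly from a live session without writing a checkpoint first. When a model already exists as a checkpoint, the same idea can be sketched as below. This is a minimal sketch, assuming the TensorFlow 1.x graph API (with a `tf.compat.v1` fallback so it also runs on TF 2.x); the temporary directory and the file names `model.ckpt` and `frozen.pb` are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 1.x graph mode as in the article; on TF 2.x use the v1 API.
if int(tf.__version__.split(".")[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_v2_behavior()

work_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(work_dir, "model.ckpt")  # hypothetical location
pb_path = os.path.join(work_dir, "frozen.pb")       # hypothetical location

# Step 1: save a tiny graph the usual way, as a checkpoint.
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, name="x")
    b = tf.Variable(1, name="b")
    tf.add(x, b, name="output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# The checkpoint is several files: .meta (graph), .index and .data-* (values).
print(sorted(os.listdir(work_dir)))

# Step 2: restore the checkpoint and replace the variables with constants.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, ckpt_prefix)
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["output"])

# Step 3: write one self-contained .pb file.
with tf.gfile.GFile(pb_path, "wb") as f:
    f.write(frozen.SerializeToString())

print([(node.name, node.op) for node in frozen.node])
```

The restored session is needed only to read the current variable values; after conversion, the .pb no longer depends on the checkpoint files.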

II. Code

1. Saving the model

import tensorflow as tf
from tensorflow.python.framework import graph_util

# input
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')

# variable b to be saved (stands in for a weight or bias)
b = tf.Variable(1, name='b')

# output
xy = tf.multiply(x, y)
output = tf.add(xy, b, name='output')  # the name 'output' is required: convert_variables_to_constants looks the node up by it

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # replace variables with constants, keeping only the nodes needed to compute 'output'
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output'])

    # write pb file
    with tf.gfile.FastGFile('model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # test: 10 * 3 + 1, prints 31
    print(sess.run(output, feed_dict={x: 10, y: 3}))
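
To confirm that freezing really turned b into a constant, and to look up the tensor names that the loading code later passes to get_tensor_by_name, the .pb can be parsed and its nodes listed. A minimal sketch, assuming the same TF 1.x API (with a `tf.compat.v1` fallback); the helper `describe_frozen_graph` is hypothetical, and the script rebuilds the article's graph into a temporary file so it runs on its own.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework import tensor_util

# Assumption: TF 1.x graph mode as in the article; on TF 2.x use the v1 API.
if int(tf.__version__.split(".")[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_v2_behavior()


def describe_frozen_graph(pb_path):
    """Hypothetical helper: list (name, op) pairs and the values of Const nodes."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    nodes = [(node.name, node.op) for node in graph_def.node]
    consts = {node.name: tensor_util.MakeNdarray(node.attr["value"].tensor)
              for node in graph_def.node if node.op == "Const"}
    return nodes, consts


# Rebuild and freeze the article's graph so this script is self-contained.
pb_path = os.path.join(tempfile.mkdtemp(), "model.pb")
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, name="x")
    y = tf.placeholder(tf.int32, name="y")
    b = tf.Variable(1, name="b")
    tf.add(tf.multiply(x, y), b, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["output"])
with tf.gfile.GFile(pb_path, "wb") as f:
    f.write(frozen.SerializeToString())

nodes, consts = describe_frozen_graph(pb_path)
for name, op in nodes:
    print(name, op)  # node names and op types (Placeholder, Const, ...)
print(consts)        # the frozen value of b is stored inside the file itself
```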

2. Loading the model

import tensorflow as tf
from tensorflow.python.platform import gfile

sess = tf.Session()

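# import the frozen graph into the session's graph
with gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# as_default() is a context manager, so it only takes effect inside a with block
with sess.graph.as_default():
    # name='' keeps the original node names ('x', 'y', 'b', 'output')
    tf.import_graph_def(graph_def, name='')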

# no initializer is needed: freezing turned variable b into a constant

# input
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

# check b, now a constant that holds the frozen value (prints 1)
print(sess.run('b:0'))

# output
output = sess.graph.get_tensor_by_name('output:0')

# test: 5 * 5 + 1, prints 26
print(sess.run(output, feed_dict={input_x: 5, input_y: 5}))
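
The loading code above imports into the global default graph. A reusable variant can import the frozen graph into its own tf.Graph and get the tensors back directly through the return_elements argument of tf.import_graph_def. A minimal sketch, assuming the TF 1.x API (with a `tf.compat.v1` fallback); the helper `load_frozen_graph` and the temporary path are hypothetical, and the article's model is rebuilt first so the script runs on its own.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 1.x graph mode as in the article; on TF 2.x use the v1 API.
if int(tf.__version__.split(".")[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_v2_behavior()


def load_frozen_graph(pb_path, tensor_names):
    """Hypothetical helper: import a frozen .pb into its own tf.Graph."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names instead of adding "import/".
        tensors = tf.import_graph_def(
            graph_def, name="", return_elements=tensor_names)
    return graph, tensors


# Write the article's frozen model to a temporary file so this runs on its own.
pb_path = os.path.join(tempfile.mkdtemp(), "model.pb")
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, name="x")
    y = tf.placeholder(tf.int32, name="y")
    b = tf.Variable(1, name="b")
    tf.add(tf.multiply(x, y), b, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["output"])
with tf.gfile.GFile(pb_path, "wb") as f:
    f.write(frozen.SerializeToString())

graph, (input_x, input_y, output) = load_frozen_graph(
    pb_path, ["x:0", "y:0", "output:0"])
with tf.Session(graph=graph) as sess:
    print(sess.run(output, feed_dict={input_x: 5, input_y: 5}))  # prints 26
```

Keeping each model in its own tf.Graph avoids node-name clashes when several frozen models are loaded in one process. No initializer is run because the file contains no variables.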

For any questions, please add my one and only QQ account, 2258205918 (name: samylee).


Reposted from blog.csdn.net/samylee/article/details/85067155