以 pb 格式保存训练得到的网络结构及模型参数

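以往我总是用 tf.train.Saver 以 checkpoint 形式保存训练结果，但做预测时需要在代码中重新构建原来的网络结构，再恢复参数。现在改用 pb 格式保存：先用 convert_variables_to_constants 把变量固化为常量，网络结构和参数就保存在同一个 .pb 文件里，预测时直接加载即可，方便多了；而且多次预测时也不需要重置网络。代码如下：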
 
 
import tensorflow as tf
import numpy as np
import os
import test1
import pickle
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.platform import gfile
# Command-line flags (TF1 `tf.app.flags` API): a checkpoint directory and a
# model-name prefix.
# NOTE(review): neither flag is read anywhere in the code shown here; they look
# like leftovers from a checkpoint/Saver-based workflow. Confirm before removing.
tf.app.flags.DEFINE_string('checkpoints_dir3', os.path.abspath('./checkpoints/travel_gan/'), 'checkpoints save path.')
tf.app.flags.DEFINE_string('model_prefix3', 'travel_gan', 'model save prefix.')
FLAGS = tf.app.flags.FLAGS
###################################################################################################
#训练过程
def train():
    """Train a toy model ``y = w*x*x2 + b`` and freeze it into ``./graph.pb``.

    The target is ``y = 4*x*x2 + 4``, so gradient descent pushes ``w`` and
    ``b`` toward 4. After training, every variable is folded into a constant
    (``convert_variables_to_constants``). The resulting GraphDef is written
    as a binary protobuf to ``./graph.pb``.

    The frozen graph keeps:

    * the placeholders ``input_0`` / ``input_1``;
    * the output nodes ``out_0``, ``out_1`` and ``out_loss``.

    The prediction functions look these up by name.
    """
    train_steps = 100

    # Build everything in a private graph. Otherwise a second call to train()
    # would add another copy of every node to the global default graph, with
    # uniquified names such as "out_0_1", while the freeze step below would
    # still export the stale first copy named "out_0".
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1], name='input_0')   # first input
        x2 = tf.placeholder(tf.float32, shape=[None, 1], name='input_1')  # second input
        # Ten random training samples per input, shaped (10, 1) to match the placeholders.
        x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
        x2_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
        y = 4 * x * x2 + 4  # ground-truth target

        w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        # Two identically computed output nodes. Both are exported so that
        # callers can fetch either name from the frozen graph.
        y_predict = tf.add(w * x * x2, b, name='out_0')
        tf.add(w * x * x2, b, name='out_1')

        loss = tf.reduce_mean(tf.square(y - y_predict), name="out_loss")
        train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

        with tf.Session() as sess:
            # global_variables_initializer replaces the deprecated
            # initialize_all_variables.
            sess.run(tf.global_variables_initializer())
            for _ in range(train_steps):
                sess.run(train_op, feed_dict={x: x_data, x2: x2_data})
            # Replace the variables with constants holding their trained
            # values, keeping only the subgraph needed for the listed outputs.
            frozen = convert_variables_to_constants(
                sess, sess.graph_def, ["out_0", "out_1", "out_loss"])
            tf.train.write_graph(frozen, '.', "graph.pb", as_text=False)

###############################################################################################################
#预测过程,直接调用前面训练过程保存的模型,并且不用重构网络
def predict():
    """Load ``./graph.pb`` and print ``out_1`` for the fixed inputs 1.0 and 2.0.

    The frozen GraphDef is imported with its placeholders ``input_0`` and
    ``input_1`` replaced (via ``input_map``) by tensors converted from the
    Python floats below, so no feed_dict is needed. Only ``out_1:0`` is
    requested, so the printed result is a one-element list.
    """
    # Read the serialized graph and close the file before running anything.
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    a = 1.0
    b = 2.0
    # Import into a private graph so that repeated calls do not keep adding
    # "a1", "a1_1", ... copies of the model to the global default graph.
    with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
        # NOTE(review): a and b are scalars, while the placeholders were
        # declared [None, 1] in train(). The mapped values presumably
        # broadcast through the graph's element-wise ops; confirm if inputs
        # change.
        output = tf.import_graph_def(
            graph_def,
            input_map={'input_0': a, 'input_1': b},
            return_elements=['out_1:0'],  # a single output
            name='a1')
        print(sess.run(output))
def predict2():
    """Load ``./graph.pb``, look up tensors by name and evaluate with a feed_dict.

    Feeds ``input_0 = [[1.0]]`` and ``input_1 = [[2.0]]``. Prints the
    ``out_loss`` tensor handle, then the evaluated values of ``out_0`` and
    ``out_loss``.
    """
    with gfile.GFile("./graph.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a dedicated graph instead of calling tf.reset_default_graph().
    # Resetting would wipe whatever else the caller had built in the global
    # default graph.
    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names (no "import/" prefix).
        tf.import_graph_def(graph_def, name='')

    input_x = graph.get_tensor_by_name("input_0:0")
    input_x2 = graph.get_tensor_by_name("input_1:0")
    output0 = graph.get_tensor_by_name("out_0:0")
    output_loss = graph.get_tensor_by_name("out_loss:0")
    # Prints the tensor handle (name/shape/dtype), not its value.
    print(output_loss)

    a = np.array([1.0])
    b = np.array([2.0])
    with tf.Session(graph=graph) as sess:
        # reshape(-1, 1) turns each 1-element vector into a (1, 1) batch that
        # matches the [None, 1] placeholders declared in train().
        result = sess.run([output0, output_loss],
                          {input_x: a.reshape(-1, 1), input_x2: b.reshape(-1, 1)})
    print(result)
def main():
    """Train and freeze the model into ./graph.pb, then run the name-lookup prediction."""
    # predict() is an alternative loader (input_map based). With the pb format
    # the graph can be reloaded repeatedly without errors.
    for stage in (train, predict2):
        stage()


if __name__ == '__main__':
    main()


猜你喜欢

转载自blog.csdn.net/weixin_38145317/article/details/79379637