# 不同图之间的分别调用、训练
# (Calling and training separately in different tf.Graph instances.)
import tensorflow as tf
import numpy as np
import os

# Command-line flags: tf.app.flags lets these values be overridden from argv.
tf.app.flags.DEFINE_string(
    'checkpoints_dir3',
    os.path.abspath('./checkpoints/'),
    'checkpoints save path.',
)
tf.app.flags.DEFINE_string(
    'model_prefix3',
    'travel_gan',
    'model save prefix.',
)
FLAGS = tf.app.flags.FLAGS
###################################################################################################
# Training
def train():
    """Fit ``y_predict = w * x + b`` to targets from ``y = 4x + 4`` and checkpoint it.

    The model is built in its own ``tf.Graph`` rather than the default graph.
    Every ``checkpoint_steps`` training steps, the session's variables are
    saved to ``FLAGS.checkpoints_dir3/FLAGS.model_prefix3``, with the step
    number passed as ``global_step``.
    """
    train_steps = 100
    checkpoint_steps = 50

    save_dir = FLAGS.checkpoints_dir3
    # tf.train.Saver.save refuses to write into a missing parent directory,
    # so make sure the checkpoint directory exists before training.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, FLAGS.model_prefix3)

    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1])
        # Ten random inputs drawn from [0, 1), reshaped into a (10, 1) column.
        x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
        # Training targets come from the known line y = 4x + 4.
        y = 4 * x + 4

        # Variable names "w" and "b" are what the Saver uses as checkpoint
        # keys; a restoring graph must declare variables with the same names.
        w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        y_predict = w * x + b

        loss = tf.reduce_mean(tf.square(y - y_predict))
        train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
        saver = tf.train.Saver()

        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Count steps from 1 so the saved global_step equals the number
            # of optimizer updates applied so far (50, 100).
            for step in range(1, train_steps + 1):
                sess.run(train_op, feed_dict={x: x_data})
                if step % checkpoint_steps == 0:
                    saver.save(sess, save_path, global_step=step)
###############################################################################################################
# Prediction
def predict():
    """Restore the newest checkpoint in ``FLAGS.checkpoints_dir3`` and predict y at x = 1.

    The model graph is rebuilt in a fresh ``tf.Graph`` with variables named
    ``"w"`` and ``"b"`` so ``tf.train.Saver`` can map the checkpoint values
    onto them. The prediction and the restored bias are printed.

    Returns:
        ``(result, bias)`` as fetched by ``sess.run``, or ``None`` when no
        checkpoint exists in ``FLAGS.checkpoints_dir3``.
    """
    checkpoint_dir = FLAGS.checkpoints_dir3
    graph = tf.Graph()
    with graph.as_default():
        # Restoring requires the graph to declare the same variables that
        # were saved, so the model structure is reconstructed here.
        x = tf.placeholder(tf.float32, shape=[1, 1], name="x")
        w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        x_data = np.array([[1.0]], dtype=np.float32)
        y_predict = w * x + b
        saver = tf.train.Saver()

        # Look up and restore from the same directory; checking a different
        # directory than the one restored from makes this silently do nothing.
        checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint is None:
            print("No checkpoint found in %s; run train() first." % checkpoint_dir)
            return None

        with tf.Session(graph=graph) as sess:
            # Overwrite the random initial values with the last trained weights.
            saver.restore(sess, checkpoint)
            result, bias = sess.run([y_predict, b], feed_dict={x: x_data})

    print(result)
    print(bias)
    return result, bias
def main():
    """Train the model, then restore the newest checkpoint and run a prediction.

    train() and predict() each build their own tf.Graph, so no
    tf.reset_default_graph() call is needed between them.
    """
    train()
    predict()


# Only run when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
    
    



# 转载自 (reposted from): blog.csdn.net/weixin_38145317/article/details/79387643