# 不同图之间的分别调用、训练 — train and predict in separate, independent tf.Graph objects.
import os

import numpy as np
import tensorflow as tf

# tf.app.flags lets these values be overridden from the command line
# (similar to parsing argv).
tf.app.flags.DEFINE_string(
    'checkpoints_dir3', os.path.abspath('./checkpoints/'),
    'checkpoints save path.')
tf.app.flags.DEFINE_string('model_prefix3', 'travel_gan', 'model save prefix.')
FLAGS = tf.app.flags.FLAGS


###############################################################################
# Training
def train():
    """Fit ``y_predict = w * x + b`` to the target ``y = 4x + 4`` and checkpoint it.

    The model is built in its own ``tf.Graph`` so it cannot collide with the
    graph built by ``predict()``. Every ``checkpoint_steps`` steps the
    variables are saved to ``FLAGS.checkpoints_dir3/FLAGS.model_prefix3``
    with the step number appended by ``Saver.save``.
    """
    train_steps = 100
    checkpoint_steps = 50
    save_prefix = os.path.join(FLAGS.checkpoints_dir3, FLAGS.model_prefix3)

    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 1])
        # Ten random inputs in [0, 1), shaped (10, 1) to match the placeholder.
        x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
        # The target is computed from the input, so the optimum is w=4, b=4.
        y = 4 * x + 4

        # Variable names must match the ones in predict() so restore works.
        w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        y_predict = w * x + b

        loss = tf.reduce_mean(tf.square(y - y_predict))
        # Named train_op so it does not shadow this function's own name.
        train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    # Saver.save refuses to write into a non-existent parent directory.
    tf.gfile.MakeDirs(FLAGS.checkpoints_dir3)

    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        for step in range(1, train_steps + 1):
            sess.run(train_op, feed_dict={x: x_data})
            if step % checkpoint_steps == 0:
                saver.save(sess, save_prefix, global_step=step)


###############################################################################
# Prediction
def predict():
    """Rebuild the model graph, restore the latest checkpoint and predict for x = 1.

    Checkpoints are looked up in ``FLAGS.checkpoints_dir3``, the same
    directory ``train()`` writes to. Prints the prediction followed by the
    restored bias ``b``. If no checkpoint exists, a message is printed
    instead.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Restoring only loads variable values, so the graph (with the same
        # variable names "w" and "b") has to be rebuilt first.
        x = tf.placeholder(tf.float32, shape=[1, 1], name="x")
        w = tf.Variable(tf.random_normal([1], -1, 1), name="w")
        b = tf.Variable(tf.zeros([1]), name="b")
        y_predict = w * x + b
        saver = tf.train.Saver()

    # Look in the directory train() saves to (returns None if there is none).
    checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir3)
    if checkpoint is None:
        print("No checkpoint found in %s; run train() first."
              % FLAGS.checkpoints_dir3)
        return

    x_data = np.array([[1.0]], dtype=np.float32)
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint)  # weights from the last saved step
        result, bias = sess.run([y_predict, b], feed_dict={x: x_data})
        print(result)
        print(bias)


def main():
    """Train the model, then restore it in a fresh graph and predict."""
    # Each function builds and runs its own tf.Graph, so they can be called
    # back to back in one process without tf.reset_default_graph().
    train()
    predict()


# NOTE(review): this runs on import as well; consider an
# `if __name__ == "__main__":` guard if the module is imported elsewhere.
main()