tensorflow__第一章:命令行参数的设定(flags=tf.app.flags)

在执行main函数之前首先进行flags的解析,也就是说TensorFlow通过设置flags来传递tf.app.run()所需要的参数,我们可以直接在程序运行前初始化flags,也可以在运行程序的时候设置命令行参数来达到传参的目的。

tf.app.flags的使用

flags = tf.app.flags flags.DEFINE_integer("epoch", 1000, "Epoch to train [25]")

flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")

flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")

flags.DEFINE_integer("train_size", 256, "The size of train images [np.inf]")

flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")

flags.DEFINE_string("dataset", "mnist", "The name of dataset [celebA, mnist, lsun]")

flags.DEFINE_boolean("train", True, "True for training, False for testing [False]")

FLAGS = flags.FLAGS

在类的初始化时可以赋值给类中的成员,例如:

dcgan = DCGAN(

sess,

input_width=FLAGS.input_width,

input_height=FLAGS.input_height,

output_width=FLAGS.output_width,

output_height=FLAGS.output_height,

batch_size=FLAGS.batch_size,

sample_num=FLAGS.batch_size,

y_dim=10,

dataset_name=FLAGS.dataset,

input_fname_pattern=FLAGS.input_fname_pattern,

crop=FLAGS.crop,

checkpoint_dir=FLAGS.checkpoint_dir,

sample_dir=FLAGS.sample_dir)

命令行的命名格式:

#第一个是参数名称,第二个参数是默认值,第三个是参数描述

tf.app.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")

tf.app.flags.DEFINE_integer('int_name', 10,"descript2")

tf.app.flags.DEFINE_boolean('bool_name', False, "descript3")

FLAGS = tf.app.flags.

FLAGS #必须带参数,否则:'TypeError: main() takes no arguments (1 given)';

main的参数名随意定义,无要求 def main(_): print(FLAGS.str_name) print(FLAGS.int_name) print(FLAGS.bool_name) if __name__ == '__main__':#避免出现import的时候调用main() tf.app.run() #执行main函数

猜你喜欢

转载自blog.csdn.net/chen_gong_ping/article/details/82377761