概述
本文总结两种参数解析的接口,一种是python的参数解析包自带的功能,使用时需要import argparse。另一类是tensorflow自带的功能,解析时import tensorflow就行了。
python中的参数解析
parse_known_args的例子
tensorflow下的一个例子
tensorflow/examples/tutorials/mnist/mnist_deep.py,入口代码如下:
import tensorflow as tf
FLAGS = None
......
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
parse_known_args功能解释
关于parse_args()的使用方法可以参考之前的博客:
python参数解析和日志记录
parse_known_args()与parse_args()功能类似,只是它只返回已知的选项,未知选项原样返回。
parse_known_args()在接受到多余的命令行参数时不报错,而是返回一个tuple类型的命名空间和一个保存着余下的命令行字符的list。
这样做的好处可以分层解析自己的选项,把剩下的参数传给tensorflow来解析。
另一个例子
下面的例子来自这个博客:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--flag_int',
type=float,
default=0.01,
help='flag_int.'
)
FLAGS, unparsed = parser.parse_known_args()
print(FLAGS)
print(unparsed)
执行输出:
$ python prog.py --flag_int 0.02 --double 0.03 a 1
Namespace(flag_int=0.02)
['--double', '0.03', 'a', '1']
argparse模块的FLAGS
上面第一个例子,FLAGS在文件开头定义,可以作为依据全局的变量使用。解析出的参数都存储在FLAGS里面。
这个FLAGS存储的是argparse模块解析出的参数。后面tensorflow也定义了自己的FLAGS,不要弄混了。
tf.app.run
第一个例子中,tf.app.run使用的方式是:
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
这句话的功能是将程序名(sys.argv[0])和未知参数交给tf.app.run解析,最终传给main函数。
run函数的定义如下:
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
# Define help flags.
_define_help_flags()
# 解析已知参数
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
#如果没有传入main参数,则会默认执行main函数
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(argv)) #将参数传给main
调用tf.app.run如果不出入任何参数,实际会调用到main函数。
这种使用tf.app.run作为入口的方式非常常见。
tensorflow的FLAGS:tf.app.flags.FLAGS
上面run函数有一行解析参数的代码调用:
# Parse known flags.
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)
意思就是,如果传入了argv则解析argv,否则解析进程的入口参数。
例子
例子1:没有指定参数格式
import argparse
import sys
import os
import tensorflow as tf
def main(args):
print("I'm main")
print('args{}'.format(args))
if __name__ == '__main__':
tf.app.run()
执行结果:
> python .\app_run.py --user=test
I'm main
args['.\\app_run.py', '--user=test']
默认情况下main会被调用,由于没有指定参数格式,所以user不能不正确解析
例子2:指定参数格式
上一小节没有定义参数格式,这样tf是解析不了参数的。定义参数的格式如下:
# 字符串
tf.app.flags.DEFINE_string("param_name", "default_val", "description")
# 布尔类型
tf.app.flags.DEFINE_boolean("param_name", "default_val", "description")
tf.app.flags.DEFINE_bool("param_name", "default_val", "description")
# 浮点型
tf.app.flags.DEFINE_float("param_name", "default_val", "description")
# 整型
tf.app.flags.DEFINE_integer("param_name", "default_val", "description")
可以看到,如果没有传入参数,则会使用默认参数。
完整的代码如下:
# coding:utf-8
# 学习使用 tf.app.flags 使用,全局变量
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string( "train_data_path", "/home/user", "training data dir" )
tf.app.flags.DEFINE_string( "log_dir", "./logs", " the log dir" )
tf.app.flags.DEFINE_integer( "max_sentence", 80, "max num of tokens per query" )
tf.app.flags.DEFINE_integer( "embedding_size", 50, "embedding size" )
tf.app.flags.DEFINE_float( "learning_rate", 0.001, "learning rate" )
def main( unused_argv ):
train_data_path = FLAGS.train_data_path
print( "train_data_path", train_data_path )
log_dir = FLAGS.log_dir
print( 'log_dir:{}'.format(log_dir) )
max_sentence = FLAGS.max_sentence
print( "max_sentence", max_sentence )
embdeeing_size = FLAGS.embedding_size
print( "embedding_size", embdeeing_size )
learning_rate = FLAGS.learning_rate
print( 'learning_rate:{}'.format(learning_rate))
# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
tf.app.run() # 解析命令行参数,调用main 函数 main(sys.argv)
我们只看到参数格式声明的代码,没看到调用解析的代码,这是因为这部分代码在tf.app.run()中。这样代码看起来会简洁很多。
输出
在不指定参数的情况下,参数被赋予默认值:
> python .\tf_argument.py
train_data_path /home/user
log_dir:./logs
max_sentence 80
embedding_size 50
learning_rate:0.001
在指定参数的情况下,参数被赋予传入值。以下两种传参方式等效。
# 方式1
> python .\tf_argument.py --train_data_path=./ --max_sentence=100 --embedding_size=100 --learning_rate=0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05
# 方式2
> python .\tf_argument.py --train_data_path ./ --max_sentence 100 --embedding_size 100 --learning_rate 0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05
总结
本文讨论了python的参数解析和tensorflow的参数解析。
在tensorflow项目代码中,推荐使用tensorflow的tf.app.flags.FLAGS来解析参数,并作为全局的配置信息。
不知为何,tensorflow的官方例程混用了这两种参数解析,目前我还没想明白。
注意:tf.app.flags.FLAGS是可以作为全局变量来用的。不同的文件不用相互import就可以使用。