tensorflow 9. 参数解析和经典入口函数tf.app.run

版权声明:本文为博主原创文章,转载请标明原始博文地址。 https://blog.csdn.net/yuanlulu/article/details/82955483

概述

本文总结两种参数解析的接口,一种是python的参数解析包自带的功能,使用时需要import argparse。另一类是tensorflow自带的功能,解析时import tensorflow就行了。

python中的参数解析

parse_known_args的例子

tensorflow下的一个例子
tensorflow/examples/tutorials/mnist/mnist_deep.py,入口代码如下:

import tensorflow as tf

FLAGS = None

......

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

parse_known_args功能解释

关于parse_args()的使用方法可以参考之前的博客:
python参数解析和日志记录

parse_known_args()与parse_args()功能类似,只是它只返回已知的选项,未知选项原样返回。

parse_known_args()在接受到多余的命令行参数时不报错,而是返回一个tuple类型的命名空间和一个保存着余下的命令行字符的list。

这样做的好处可以分层解析自己的选项,把剩下的参数传给tensorflow来解析。

另一个例子

下面的例子来自这个博客

import argparse 
parser = argparse.ArgumentParser() 
parser.add_argument( 
    '--flag_int', 
    type=float, 
    default=0.01, 
    help='flag_int.' 
) 
FLAGS, unparsed = parser.parse_known_args() 
print(FLAGS) 
print(unparsed)

执行输出:

$ python prog.py --flag_int 0.02 --double 0.03 a 1
Namespace(flag_int=0.02)
['--double', '0.03', 'a', '1']

argparse模块的FLAGS

上面第一个例子,FLAGS在文件开头定义,可以作为依据全局的变量使用。解析出的参数都存储在FLAGS里面。

这个FLAGS存储的是argparse模块解析出的参数。后面tensorflow也定义了自己的FLAGS,不要弄混了。

tf.app.run

第一个例子中,tf.app.run使用的方式是:

tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

这句话的功能是将程序名(sys.argv[0])和未知参数交给tf.app.run解析,最终传给main函数。

run函数的定义如下:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""

  # Define help flags.
  _define_help_flags()

  # 解析已知参数
  argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)

  #如果没有传入main参数,则会默认执行main函数
  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(argv)) #将参数传给main

调用tf.app.run如果不出入任何参数,实际会调用到main函数。
这种使用tf.app.run作为入口的方式非常常见。

tensorflow的FLAGS:tf.app.flags.FLAGS

上面run函数有一行解析参数的代码调用:

# Parse known flags.
argv = flags.FLAGS(_sys.argv if argv is None else argv, known_only=True)

意思就是,如果传入了argv则解析argv,否则解析进程的入口参数。

例子

例子1:没有指定参数格式

import argparse
import sys
import os
import tensorflow as tf


def main(args):
    print("I'm main")
    print('args{}'.format(args))

if __name__ == '__main__':
    tf.app.run()

执行结果:

> python .\app_run.py --user=test
I'm main
args['.\\app_run.py', '--user=test']

默认情况下main会被调用,由于没有指定参数格式,所以user不能不正确解析

例子2:指定参数格式

上一小节没有定义参数格式,这样tf是解析不了参数的。定义参数的格式如下:

# 字符串
tf.app.flags.DEFINE_string("param_name", "default_val", "description")
# 布尔类型
tf.app.flags.DEFINE_boolean("param_name", "default_val", "description")
tf.app.flags.DEFINE_bool("param_name", "default_val", "description")
# 浮点型
tf.app.flags.DEFINE_float("param_name", "default_val", "description")
# 整型
tf.app.flags.DEFINE_integer("param_name", "default_val", "description")

可以看到,如果没有传入参数,则会使用默认参数。

完整的代码如下:

# coding:utf-8

# 学习使用 tf.app.flags 使用,全局变量
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# tf.app.flags.DEFINE_string("param_name", "default_val", "description")
tf.app.flags.DEFINE_string( "train_data_path", "/home/user", "training data dir" )
tf.app.flags.DEFINE_string( "log_dir", "./logs", " the log dir" )
tf.app.flags.DEFINE_integer( "max_sentence", 80, "max num of tokens per query" )
tf.app.flags.DEFINE_integer( "embedding_size", 50, "embedding size" )

tf.app.flags.DEFINE_float( "learning_rate", 0.001, "learning rate" )


def main( unused_argv ):
    train_data_path = FLAGS.train_data_path
    print( "train_data_path", train_data_path )
    log_dir = FLAGS.log_dir
    print( 'log_dir:{}'.format(log_dir) )
    max_sentence = FLAGS.max_sentence
    print( "max_sentence", max_sentence )
    embdeeing_size = FLAGS.embedding_size
    print( "embedding_size", embdeeing_size )
    learning_rate = FLAGS.learning_rate
    print( 'learning_rate:{}'.format(learning_rate))


# 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
if __name__ == '__main__':
    tf.app.run()  # 解析命令行参数,调用main 函数 main(sys.argv)

我们只看到参数格式声明的代码,没看到调用解析的代码,这是因为这部分代码在tf.app.run()中。这样代码看起来会简洁很多。

输出

在不指定参数的情况下,参数被赋予默认值:

> python .\tf_argument.py
train_data_path /home/user
log_dir:./logs
max_sentence 80
embedding_size 50
learning_rate:0.001

在指定参数的情况下,参数被赋予传入值。以下两种传参方式等效。

# 方式1
> python .\tf_argument.py --train_data_path=./  --max_sentence=100  --embedding_size=100 --learning_rate=0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05

# 方式2
> python .\tf_argument.py --train_data_path ./  --max_sentence 100  --embedding_size 100 --learning_rate 0.05
train_data_path ./
log_dir:./logs
max_sentence 100
embedding_size 100
learning_rate:0.05

总结

本文讨论了python的参数解析和tensorflow的参数解析。

在tensorflow项目代码中,推荐使用tensorflow的tf.app.flags.FLAGS来解析参数,并作为全局的配置信息。

不知为何,tensorflow的官方例程混用了这两种参数解析,目前我还没想明白。

注意:tf.app.flags.FLAGS是可以作为全局变量来用的。不同的文件不用相互import就可以使用。

参考资料

python参数解析和日志记录

TensorFlow 中 tf.app.flags.FLAGS 的用法介绍

argarse.ArgumentParser.parse_known_args()解析

猜你喜欢

转载自blog.csdn.net/yuanlulu/article/details/82955483