Tensorflow:tf.app.run()与命令行参数解析

首先给出一段常见的代码:

if __name__ == '__main__':    
    tf.app.run()12

它是函数入口,通过处理flag解析,然后执行main函数(或者接下来提到的xxx())(最后含有tf.app.run()的文件,在此行之前肯定能找到def main(_)或者在tf.app.run(xxx())之前找到def xxx().)用主函数和命令行参数列表来运行程序

找到Tensorflow中关于上述函数run()的源码:

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
    'run',
    # Allowed submodule.
    'flags',
]

remove_undocumented(__name__, _allowed_symbols)

可以看到源码中的过程是首先加载flags的参数项,然后执行main函数。其中参数是使用tf.app.flags.FLAGS定义的。

关于tf.app.flags.FLAGS的使用:

# fila_name: temp.py
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

print('string: ', FLAGS.string)
print('learning_rate: ', FLAGS.learning_rate)
print('flag: ', FLAGS.flag)

输出:

string:  train
learning_rate:  0.001
flag:  True

如果在命令行中执行python3 temp.py --help,得到输出:

usage: temp.py [-h] [--string STRING] [--learning_rate LEARNING_RATE]
               [--flag [FLAG]] [--noflag]

optional arguments:
  -h, --help            show this help message and exit
  --string STRING       This is a string
  --learning_rate LEARNING_RATE
                        This is the rate in training
  --flag [FLAG]         This is a flag
  --noflag

如果要对FLAGS的默认值进行修改,只要输入命令:

python3 temp.py --string 'test' --learning_rate 0.2 --flag False

联合使用

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('string', 'train', 'This is a string')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'This is the rate in training')
tf.app.flags.DEFINE_boolean('flag', True, 'This is a flag')

def main(unuse_args):
    print('string: ', FLAGS.string)
    print('learning_rate: ', FLAGS.learning_rate)
    print('flag: ', FLAGS.flag)

if __name__ == '__main__':
    tf.app.run()

主函数中的tf.app.run()会调用main,并传递参数,因此必须在main函数中设置一个参数位置。如果要更换main名字,只需要在tf.app.run()中传入一个指定的函数名即可。

def test(args):
    # test
    ...
if __name__ == '__main__':
    tf.app.run(test)

os.path.basename(),返回path最后的文件名。若path以/或\结尾,那么就会返回空值。

path='D:\CSDN'

os.path.basename(path)=CSDN

猜你喜欢

转载自blog.csdn.net/chanbo8205/article/details/85309476