tf.app.run(main=None, argv=None)
TensorFlow
https://tensorflow.google.cn/
TensorFlow -> API
https://tensorflow.google.cn/versions
TensorFlow 1.x -> r1.4
https://github.com/tensorflow/docs/tree/r1.4/site/en/api_docs
tf.app.run(main=None, argv=None)
执行程序中的 main(_)
函数,并解析命令行参数。可选参数是 main=None, argv=None
,argv 是输入列表。
1. /usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic entry point script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys as _sys
from tensorflow.python.platform import flags
from tensorflow.python.util.all_util import remove_undocumented
def _benchmark_tests_can_log_memory():
return True
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
_allowed_symbols = [
'run',
# Allowed submodule.
'flags',
]
remove_undocumented(__name__, _allowed_symbols)
flags_passthrough = f._parse_flags(args=args)
用来解析命令行参数的函数。
2. tf.app.run()
2.1 调用示例
if __name__ == '__main__':
tf.app.run()
Runs the program with an optional ‘main’ function and ‘argv’ list.
该 if 语句用来判断,该模块是正在被 import 还是被 SHELL 单独运行,被 SHELL 单独运行时,为 True。在该模块被 import 时,main 为模块中的一个函数。
2.2 函数原型
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
主函数中的 tf.app.run()
会调用 main(_)
,并传递参数,因此必须在 main 函数中设置一个参数。如果要更换 main 函数为 test(args)
,只需要在 tf.app.run()
中传入一个指定的函数名即可 tf.app.run(test)
。
如果你的代码中的入口函数不叫 main(_)
,而是一个其他名字的函数 test(args)
,则你应该这样写入口 tf.app.run(test)
。
2.3 代码解析
f = flags.FLAGS
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags 用于接收命令行传递参数。
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
将传入的参数从第二个开始切片 copy 到 args 形成一个列表 (第一个为函数名),如果没有则传 None。
FLAGS = tf.app.flags.FLAGS
语句存在,表示输入已经解析。tf.app.run()
中 argv=None
,通过 args = argv[1:] if argv else None
语句,可知 args=None
(即不指定,后面会自动解析 command)。
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
f = flags.FLAGS
构造了解析器 f
用以解析 args
,flags_passthrough = f._parse_flags(args=args)
解析 args
列表或者 command 输入。args
列表为空,则解析 command 输入,返回的 flags_passthrough
内为无法解析的数据列表 (不包括文件名)。
main = main or sys.modules['__main__'].main
默认执行参数中指定的 main 函数。若 main=None
,则默认程序中 main(_)
函数。
在没有传入主函数参数时,就认为当前模块中已经有一个叫 main(_)
的主函数,将 main 赋给 main。在传入主函数参数时,将传入的当前模块自己定义的主函数传给等号左边的 main。
The first main
in right side of =
is the first argument of current function run(main=None, argv=None)
. While sys.modules['__main__']
means current running file (e.g. my_model.py).
sys.exit(main(sys.argv[:1] + flags_passthrough))
调用 main 函数,参数为文件名 + 无法解析数据的列表。
如果是从其它模块调用该模块程序,则不会运行 main 函数。如果直接运行该模块程序,则会运行 main 函数。如果此文件被其他文件 import 的时候,不会执行 main 函数。
3 example
3.1 命令行执行
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags.DEFINE_integer('train_batch_size', 12, 'The number of images in each batch during training.')
tf.app.flags.DEFINE_boolean("is_train", True, "")
def main(_):
print("{}".format(FLAGS.train_batch_size))
print("{}".format(FLAGS.is_train))
print(_)
if __name__ == '__main__':
tf.app.run()
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py -h
usage: test.py [-h] [--train_batch_size TRAIN_BATCH_SIZE]
[--is_train [IS_TRAIN]] [--nois_train]
optional arguments:
-h, --help show this help message and exit
--train_batch_size TRAIN_BATCH_SIZE
The number of images in each batch during training.
--is_train [IS_TRAIN]
--nois_train
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py
12
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64
64
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False
64
False
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['test.py', '--gpus', '0']
strong@foreverstrong:~/git_workspace/MonoGRNet$
3.2 PyCharm 执行
- 参数设置
- 调试过程
- 调试过程
- 调试过程
- 运行结果
/usr/bin/python2.7 /home/strong/git_workspace/MonoGRNet/test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['/home/strong/git_workspace/MonoGRNet/test.py', '--gpus', '0']
Process finished with exit code 0