tf.app.run(main=None, argv=None)

tf.app.run(main=None, argv=None)

TensorFlow
https://tensorflow.google.cn/

TensorFlow -> API
https://tensorflow.google.cn/versions

TensorFlow 1.x -> r1.4
https://github.com/tensorflow/docs/tree/r1.4/site/en/api_docs

tf.app.run(main=None, argv=None) 执行程序中的 main(_) 函数,并解析命令行参数。可选参数是 main=None, argv=None,argv 是输入列表。

1. /usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generic entry point script."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys as _sys

from tensorflow.python.platform import flags
from tensorflow.python.util.all_util import remove_undocumented


def _benchmark_tests_can_log_memory():
  return True


def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or _sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  _sys.exit(main(_sys.argv[:1] + flags_passthrough))


_allowed_symbols = [
    'run',
    # Allowed submodule.
    'flags',
]

remove_undocumented(__name__, _allowed_symbols)

flags_passthrough = f._parse_flags(args=args) 用来解析命令行参数的函数。

2. tf.app.run()

2.1 调用示例

if __name__ == '__main__':
    tf.app.run()

Runs the program with an optional ‘main’ function and ‘argv’ list.
该 if 语句用来判断,该模块是正在被 import 还是被 SHELL 单独运行,被 SHELL 单独运行时,为 True。在该模块被 import 时,main 为模块中的一个函数。

2.2 函数原型

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""

主函数中的 tf.app.run() 会调用 main(_),并传递参数,因此必须在 main 函数中设置一个参数。如果要更换 main 函数为 test(args),只需要在 tf.app.run() 中传入一个指定的函数名即可 tf.app.run(test)
如果你的代码中的入口函数不叫 main(_),而是一个其他名字的函数 test(args),则你应该这样写入口 tf.app.run(test)

2.3 代码解析

f = flags.FLAGS
flags = tf.app.flags
FLAGS = flags.FLAGS

tf.app.flags 用于接收命令行传递参数。

# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None

将传入的参数从第二个开始切片 copy 到 args 形成一个列表 (第一个为函数名),如果没有则传 None。
FLAGS = tf.app.flags.FLAGS 语句存在,表示输入已经解析。tf.app.run()argv=None,通过 args = argv[1:] if argv else None 语句,可知 args=None (即不指定,后面会自动解析 command)。

# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access

f = flags.FLAGS 构造了解析器 f 用以解析 argsflags_passthrough = f._parse_flags(args=args) 解析 args 列表或者 command 输入。args 列表为空,则解析 command 输入,返回的 flags_passthrough 内为无法解析的数据列表 (不包括文件名)。

main = main or sys.modules['__main__'].main 默认执行参数中指定的 main 函数。若 main=None,则默认程序中 main(_) 函数。
在没有传入主函数参数时,就认为当前模块中已经有一个叫 main(_) 的主函数,将 main 赋给 main。在传入主函数参数时,将传入的当前模块自己定义的主函数传给等号左边的 main。
The first main in right side of = is the first argument of current function run(main=None, argv=None). While sys.modules['__main__'] means current running file (e.g. my_model.py).

sys.exit(main(sys.argv[:1] + flags_passthrough)) 调用 main 函数,参数为文件名 + 无法解析数据的列表。

如果是从其它模块调用该模块程序,则不会运行 main 函数。如果直接运行该模块程序,则会运行 main 函数。如果此文件被其他文件 import 的时候,不会执行 main 函数。

3 example

3.1 命令行执行

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

tf.app.flags.DEFINE_integer('train_batch_size', 12, 'The number of images in each batch during training.')
tf.app.flags.DEFINE_boolean("is_train", True, "")


def main(_):
    print("{}".format(FLAGS.train_batch_size))
    print("{}".format(FLAGS.is_train))
    print(_)


if __name__ == '__main__':
    tf.app.run()
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py -h
usage: test.py [-h] [--train_batch_size TRAIN_BATCH_SIZE]
               [--is_train [IS_TRAIN]] [--nois_train]

optional arguments:
  -h, --help            show this help message and exit
  --train_batch_size TRAIN_BATCH_SIZE
                        The number of images in each batch during training.
  --is_train [IS_TRAIN]
  --nois_train
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py
12
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64
64
True
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False
64
False
['test.py']
strong@foreverstrong:~/git_workspace/MonoGRNet$ 
strong@foreverstrong:~/git_workspace/MonoGRNet$ python test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['test.py', '--gpus', '0']
strong@foreverstrong:~/git_workspace/MonoGRNet$

3.2 PyCharm 执行

  • 参数设置

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 调试过程

在这里插入图片描述

  • 运行结果
/usr/bin/python2.7 /home/strong/git_workspace/MonoGRNet/test.py --train_batch_size 64 --is_train False --gpus 0
64
False
['/home/strong/git_workspace/MonoGRNet/test.py', '--gpus', '0']

Process finished with exit code 0

猜你喜欢

转载自blog.csdn.net/chengyq116/article/details/93389921