TensorFlow2.0教程30:使用tf.function和AutoGraph提高代码性能

  在TensorFlow 2.0中,默认情况下启用了急切执行。 对于用户而言直观且灵活(运行一次性操作更容易,更快),但这可能会牺牲性能和可部署性。

  要获得最佳性能并使模型可在任何地方部署,可以优先使用tf.function从程序中构建图。 因为有AutoGraph,可以使用tf.function构建高效性能的Python代码,但仍有一些陷阱需要警惕。

  今天我们就来介绍一下tensorflow2.0中的TF fuction和AutoGraph。

  下面的辅助程序代码,用于演示可能遇到的各种错误。

  import contextlib

  # 构建包含上下文管理器的函数,使其可以在with中使用

  @contextlib.contextmanager

  def assert_raises(error_class):

  try:

  yield

  except error_class as e:

  print('Caught expected exception \n {}: {}'.format(error_class, e))

  except Exception as e:

  print('Got unexpected exception \n {}: {}'.format(type(e), e))

  else:

  raise Exception('Expected {} to be raised but no error was raised!'.format(

  error_class))

  tf.function

  一个tf.function定义就像是一个核心TensorFlow操作:可以急切地执行它; 也可以在静态图中使用它; 且它具有梯度。

  # 类似一个tensorflow操作

  @tf.function

  def add(a, b):

  return a+b

  add(tf.ones([2,2]), tf.ones([2,2]))

  array([[2., 2.],

  [2., 2.]], dtype=float32)>

  # tf.function操作可以计算梯度

  @tf.function

  def add(a, b):

  return a+b

  v = tf.Variable(2.0)

  with tf.GradientTape() as tape:

  res = add(v, 1.0)

  tape.gradient(res, v)

  # 可以内嵌调用tf.function

  @tf.function

  def dense_layer(x, w, b):

  return add(tf.matmul(x, w), b)

  dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

  array([[3., 3.],

  [3., 3.],

  [3., 3.]], dtype=float32)>

  跟踪和多态

  Python的动态类型意味着可以使用各种参数类型调用函数,Python将在每个场景中执行不同的操作。

  另一方面,TensorFlow图需要静态dtypes和形状尺寸。tf.function通过在必要时回溯函数来生成正确的图结构来弥补这一差距。大多数使用的tf.function源于这种回归行为。

  我们可以使用不同类型的参数调用函数来查看正在发生的事情。

  # 函数的多态

  @tf.function

  def double(a):

  print('追踪变量:',a)

  return a + a

  print('结果:',double(tf.constant(1)))

  print()

  print('结果:',double(tf.constant(1.1)))

  print()

  print('结果:',double(tf.constant('c')))

  print()

  追踪变量: Tensor("a:0", shape=(), dtype=int32)

  结果: tf.Tensor(2, shape=(), dtype=int32)

  追踪变量: Tensor("a:0", shape=(), dtype=float32)

  结果: tf.Tensor(2.2, shape=(), dtype=float32)

  追踪变量: Tensor("a:0", shape=(), dtype=string)

  结果: tf.Tensor(b'cc', shape=(), dtype=string)

  控制参数类型:

  创建一个新的tf.function。tf.function确保单独的对象不共享追踪。

  使用该get_concrete_function方法获取特定追踪

  指定input_signature何时调用tf.function以确保仅构建一个功能图。

  print('构建许可的追踪')

  double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))

  print("执行追踪函数")

  print(double_strings(tf.constant("a")))

  print(double_strings(a=tf.constant("b")))

  print("使用不合法参数")

  with assert_raises(tf.errors.InvalidArgumentError):

  double_strings(tf.constant(1))

  构建许可的追踪

  追踪变量: Tensor("a:0", dtype=string)

  执行追踪函数

  tf.Tensor(b'aa', shape=(), dtype=string)

  tf.Tensor(b'bb', shape=(), dtype=string)

  使用不合法参数

  Caught expected exception

  : cannot compute __inference_double_98 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_98]

  @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))

  def next_collatz(x):

  print("Tracing with", x)

  return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)

  print(next_collatz(tf.constant([1, 2])))

  # 只能输入1维向量

  with assert_raises(ValueError):

  next_collatz(tf.constant([[1, 2], [3, 4]]))

  Tracing with Tensor("x:0", shape=(None,), dtype=int32)

  tf.Tensor([4 1], shape=(2,), dtype=int32)

  Caught expected exception

  : Python inputs incompatible with input_signature: inputs ((

  array([[1, 2],

  [3, 4]], dtype=int32)>,)), input_signature ((TensorSpec(shape=(None,), dtype=tf.int32, name=None),))

  什么时候回溯?

  多态tf.function通过跟踪生成具体函数的缓存。缓存键实际上是从函数args和kwargs生成的键的元组。为tf.Tensor参数生成的关键是其形状和类型。为Python原语生成的密钥是它的值。对于所有其他Python类型,键都基于对象,id()以便为每个类的实例独立跟踪方法。将来,TensorFlow可以为Python对象添加更复杂的缓存,可以安全地转换为张量。

  使用Python参数还是Tensor参数?

  通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。因此,如果Python参数发生变化,那么必须回溯图。

  但是,Python参数可能不会用于控制图构造。在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,这个训练循环,AutoGraph将动态展开。尽管存在多条迹线,但生成的图实际上是相同的,因此这有点低效。

  def train_one_step():

  pass

  @tf.function

  def train(num_steps):

  print("追踪: num_steps = {}".format(num_steps))

  for _ in tf.range(num_steps):

  train_one_step()

  train(num_steps=10)

  train(num_steps=20)

  追踪: num_steps = 10

  追踪: num_steps = 20

  # 使用tensor,同类型不会重复追踪

  train(num_steps=tf.constant(10))

  train(num_steps=tf.constant(20))

  追踪: num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

  # 使用tensor,类型不同才会有新的追踪,(前一个单元格已追踪int型,所以该处不追踪)

  train(num_steps=tf.constant(10, dtype=tf.int32))

  train(num_steps=tf.constant(20.6))

  追踪: num_steps = Tensor("num_steps:0", shape=(), dtype=float32)

  副作用 tf.function

  通常,Python副作用(如打印或变异对象)仅在跟踪期间发生。怎么能可靠地触发副作用tf.function呢?

  一般的经验法则是仅使用Python副作用来调试跟踪。但是,TensorFlow操作类似于tf.Variable.assign,tf.print和tf.summary是确保TensorFlow运行时,在每次调用时跟踪和执行代码的最佳方法。通常使用功能样式将产生最佳结果。

  tf.function函数中的print()被用于跟踪,所以要调试输出每次调用(副作用),就需要tf.function()

  @tf.function

  def f(x):

  print("追踪:", x)

  tf.print('执行:', x)

  f(1)

  f(1)

  f(2)

  追踪: 1

  执行: 1

  执行: 1

  追踪: 2

  执行: 2

  如果想在每次调用期间执行Python代码tf.function,可以使用tf.py_function。tf.py_function缺点是它不便携和高效,也不能在分布式(多GPU,TPU)设置中很好地工作。此外,由于tf.py_function必须连接到图,它将所有输入/输出转换为张量。

  external_list = []

  def side_effect(x):

  print('Python side effect')

  external_list.append(x)

  @tf.function

  def f(x):

  tf.py_function(side_effect, inp=[x], Tout=[])

  f(1)

  f(1)

  f(1)

  print(external_list)

  WARNING: Logging before flag parsing goes to stderr.

  W0609 06:41:05.048375 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32

  W0609 06:41:05.053524 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32

  W0609 06:41:05.056409 139792226170624 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32

  Python side effect

  Python side effect

  Python side effect

  [, , ]

  谨防Python状态

  许多Python功能(如生成器和迭代器)依赖于Python运行时来跟踪状态。 通常,虽然这些构造在Eager模式下按预期工作,但由于跟踪行为,tf.function内部可能会发生许多意外情况。

  external_var = tf.Variable(0)

  @tf.function

  def buggy_consume_next(iterator):

  external_var.assign_add(next(iterator))

  tf.print('external_var:', external_var)

  iterator = iter([0,1,2,3])

  buggy_consume_next(iterator)

  # 后面没有正常迭代,输出的都是第一个

  buggy_consume_next(iterator)

  buggy_consume_next(iterator)

  external_var: 0

  external_var: 0

  external_var: 0

  如果在tf.function中生成并完全使用了迭代器,那么它应该可以正常工作。但是,整个迭代器可能正在被跟踪,这可能导致一个巨大的图。如果正在训练一个表示为Python列表的大型内存数据集,那么这会生成一个非常大的图,并且tf.function不太可能产生加速。

  如果要迭代Python数据,最安全的方法是将其包装在tf.data.Dataset中并使用该for x in y惯用法。AutoGraph特别支持tf.data.Dataset 时安全地转换循环。

  def measure_graph_size(f, *args):

  g = f.get_concrete_function(*args).graph

  print("{}({}) 的图中包含了 {} 个节点".format(

  f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

  @tf.function

  def train(dataset):

  loss = tf.constant(0)

  for x, y in dataset:

  loss += tf.abs(y - x) # Some dummy computation.

  return loss

  small_data = [(1, 1)] * 2

  big_data = [(1, 1)] * 10

  measure_graph_size(train, small_data)

  measure_graph_size(train, big_data)

  measure_graph_size(train, tf.data.Dataset.from_generator(

  lambda: small_data, (tf.int32, tf.int32)))

  measure_graph_size(train, tf.data.Dataset.from_generator(

  lambda: big_data, (tf.int32, tf.int32)))

  train([(1, 1), (1, 1)]) 的图中包含了 8 个节点

  train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) 的图中包含了 32 个节点

  train(, ), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点

  train(, ), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点

  在数据集中包装Python / Numpy数据时,请注意tf.data.Dataset.from_generator与tf.data.Dataset.from_tensors。前者将数据保存在Python中并通过tf.py_function它获取性能影响,而后者将数据的副本捆绑为图中的一个大tf.constant()节点,这可能会对内存产生影响。

  通过TFRecordDataset / CsvDataset / etc从文件中读取数据。是最有效的数据处理方式,因为TensorFlow本身可以管理数据的异步加载和预取,而不必涉及Python。

  自动控制依赖项

  在一般数据流图上,作为编程模型的函数的一个非常吸引人的特性是函数可以为运行时提供有关代码的预期行为的更多信息。

  例如,当编写具有多个读取和写入相同变量的代码时,数据流图可能不会自然地编码最初预期的操作顺序。在tf.function,我们通过引用原始Python代码中的语句的执行顺序来解决执行顺序中的歧义。这样,有序状态操作的排序tf.function复制了Eager模式的语义。

  这意味着不需要添加手动控制依赖项; tf.function足够聪明,可以为代码添加最小的必要和充分的控制依赖关系,以便正确运行。

  # 按顺序自动执行

  a = tf.Variable(1.0)

  b = tf.Variable(2.0)

  @tf.function

  def f(x, y):

  a.assign(y * b)

  b.assign_add(x * a)

  return a + b

  f(1.0, 2.0)

  变量

  我们可以使用相同的想法来利用代码的预期执行顺序,使变量创建和利用变得非常容易tf.function。但是有一个非常重要的警告,即使用变量,可以编写在急切模式和图形模式下表现不同的代码。

  具体来说,每次调用创建一个新变量时都会发生这种情况。由于跟踪语义,tf.function每次调用都会重用相同的变量,但是eager模式会在每次调用时创建一个新变量。为防止出现此错误,tf.function如果检测到危险变量创建行为,则会引发错误。

  @tf.function

  def f(x):

  # tf.function会重复调用相同变量,而eager每次都会创建新的变量

  v = tf.Variable(1.0)

  v.assign_add(x)

  return v

  with assert_raises(ValueError):

  f(1.0)

  Caught expected exception

  : in converted code:

  :4 f *

  v = tf.Variable(1.0)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__

  return cls._variable_v2_call(*args, **kwargs)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call

  shape=shape)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:60 getter

  return captured_getter(captured_previous, **kwargs)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:364 invalid_creator_scope

  "tf.function-decorated function tried to create "

  ValueError: tf.function-decorated function tried to create variables on non-first call.

  不会报错的方法是

  v = tf.Variable(1.0) # 把变量拿到tf.function外面

  @tf.function

  def f(x):

  return v.assign_add(x)

  print(f(1.0)) # 2.0

  print(f(2.0)) # 4.0

  tf.Tensor(2.0, shape=(), dtype=float32)

  tf.Tensor(4.0, shape=(), dtype=float32)

  也可以在tf.function中创建变量,只要可以保证这些变量仅在第一次执行函数时创建。

  class C: pass

  obj = C(); obj.v = None

  @tf.function

  def g(x):

  if obj.v is None:

  obj.v = tf.Variable(1.0)

  return obj.v.assign_add(x)

  print(g(1.0)) # 2.0

  print(g(2.0)) # 4.0

  tf.Tensor(2.0, shape=(), dtype=float32)

  tf.Tensor(4.0, shape=(), dtype=float32)

  变量初始值设定项可以依赖于函数参数和其他变量的值。 我们可以使用与生成控制依赖关系相同的方法找出正确的初始化顺序。

  state = []

  @tf.function

  def fn(x):

  if not state:

  state.append(tf.Variable(2.0 * x))

  state.append(tf.Variable(state[0] * 3.0))

  return state[0] * x * state[1]

  print(fn(tf.constant(1.0)))

  print(fn(tf.constant(3.0)))

  tf.Tensor(12.0, shape=(), dtype=float32)

  tf.Tensor(36.0, shape=(), dtype=float32)

  使用AutoGraph

  该签名库完全集成tf.function,它将改写条件和循环依赖于张量在图形动态运行。

  tf.cond并且tf.while_loop继续使用tf.function,但是当以命令式样式编写时,具有控制流的代码通常更容易编写和理解。

  # 简单的循环

  @tf.function

  def f(x):

  # 直接用python中的while写循环

  while tf.reduce_sum(x) > 1:

  tf.print(x)

  x = tf.tanh(x)

  return x

  f(tf.random.uniform([5]))

  [0.829342961 0.858322263 0.900950909 0.851897 0.530384183]

  [0.680123031 0.695392191 0.716760576 0.692059278 0.485674709]

  [0.591599405 0.601434886 0.614898741 0.599303305 0.450776756]

  [0.53104496 0.538069844 0.547566235 0.536553681 0.422537297]

  [0.486179501 0.491525501 0.498693913 0.490374774 0.399065822]

  [0.451178908 0.455426365 0.461089343 0.454513818 0.379149348]

  [0.422867566 0.426349223 0.430971652 0.425602287 0.361968517]

  [0.399343461 0.402265817 0.406133026 0.401639521 0.346946776]

  [0.379387051 0.381885976 0.385184318 0.381350905 0.333665]

  [0.362175018 0.36434418 0.367201209 0.363880038 0.321810097]

  [0.347128421 0.349034756 0.351541221 0.348627061 0.311142713]

  [0.333826423 0.335519224 0.337741673 0.335157365 0.30147627]

  [0.321954757 0.323471278 0.325459719 0.323147237 0.292663]

  [0.311273336 0.312642276 0.314435244 0.312349856 0.284584]

  [0.301595032 0.302838922 0.304466605 0.302573323 0.277142316]

  [0.292771578 0.293908447 0.295394808 0.293665737 0.270258158]

  [0.284683794 0.285728157 0.287092626 0.285505235 0.263865024]

  [0.277234435 0.278198302 0.279456645 0.277992576 0.257907033]

  [0.270343572 0.271236718 0.272402078 0.271046132 0.25233686]

  [0.263944477 0.264775217 0.265858531 0.264597982 0.247114092]

  [0.257981181 0.258756459 0.259766966 0.258591145 0.242203966]

  [0.252406299 0.253132015 0.254077554 0.252977312 0.237576365]

  [0.24717927 0.247860536 0.248747766 0.247715324 0.233205199]

  [0.242265314 0.242906466 0.24374117 0.242769822 0.229067564]

  [0.237634286 0.238239139 0.239026278 0.238110229 0.225143358]

  [0.233259991 0.233831868 0.234575793 0.233709976 0.221414775]

  [0.229119495 0.229661271 0.230365857 0.229545817 0.217866093]

  [0.225192651 0.22570689 0.22637549 0.225597292 0.214483246]

  [0.221461684 0.221950635 0.222586185 0.221846417 0.211253688]

  [0.217910782 0.218376443 0.218981609 0.218277216 0.208166167]

  [0.214525893 0.214970052 0.215547174 0.214875415 0.205210552]

  [0.211294428 0.211718708 0.212269917 0.211628318 0.202377662]

  [0.208205134 0.208611 0.209138155 0.20852454 0.199659243]

  [0.205247864 0.205636591 0.206141427 0.2055538 0.197047815]

  [0.20241344 0.202786222 0.203270242 0.202706844 0.194536477]

  array([0.19969359, 0.2000515 , 0.2005161 , 0.19997531, 0.192119 ],

  dtype=float32)>

  print(f)

  可以检查代码签名生成。 但感觉就像阅读汇编语言一样。

  def f(x):

  while tf.reduce_sum(x) > 1:

  tf.print(x)

  x = tf.tanh(x)

  return x

  print(tf.autograph.to_code(f))

  def tf__f(x):

  do_return = False

  retval_ = ag__.UndefinedReturnValue()

  def loop_test(x_1):

  return ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None) > 1

  def loop_body(x_1):

  ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)

  x_1 = ag__.converted_call('tanh', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)

  return x_1,

  x, = ag__.while_stmt(loop_test, loop_body, (x,))

  do_return = True

  retval_ = x

  cond = ag__.is_undefined_return(retval_)

  def get_state():

  return ()

  def set_state(_):

  pass

  def if_true():

  retval_ = None

  return retval_

  def if_false():

  return retval_

  retval_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)

  return retval_

  AutoGraph:条件

  AutoGraph会将if语句转换为等效的tf.cond调用。

  如果条件是Tensor,则进行此替换。否则,在跟踪期间执行条件。

  # 测试

  def test_tf_cond(f, *args):

  # 获取图

  g = f.get_concrete_function(*args).graph

  if any(node.name=='cond' for node in g.as_graph_def().node):

  print("{}({}) 使用 tf.cond.".format(

  f.__name__, ', '.join(map(str, args))))

  else:

  print("{}({}) 正常执行.".format(

  f.__name__, ', '.join(map(str, args))))

  只有条件为tensor,才会使用tf.cond

  @tf.function

  def hyperparam_cond(x, training=True):

  if training:

  x = tf.nn.dropout(x, rate=0.5)

  return x

  @tf.function

  def maybe_tensor_cond(x):

  if x < 0:

  x = -x

  return x

  test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))

  test_tf_cond(maybe_tensor_cond, tf.constant(-1)) # 条件为tensor

  test_tf_cond(maybe_tensor_cond, -1)

  hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) 正常执行.

  maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) 使用 tf.cond.

  maybe_tensor_cond(-1) 正常执行.

  tf.cond有一些细微之处。 - 它的工作原理是跟踪条件的两边,然后根据条件在运行时选择适当的分支。跟踪双方可能导致意外执行Python代码 - 它要求如果一个分支创建下游使用的张量,另一个分支也必须创建该张量。

  @tf.function

  def f():

  x = tf.constant(0)

  if tf.constant(True):

  x = x + 1

  tf.print('执行,x:', x)

  print("Tracing `then` branch")

  else:

  x = x - 1

  tf.print('执行,x:', x) # 没有执行

  print("Tracing `else` branch") # 该分支虽然不执行但也被追踪

  return x

  f()

  Tracing `then` branch

  Tracing `else` branch

  执行,x: 1

  两个分支必须都定义x

  @tf.function

  def f():

  if tf.constant(True):

  x = tf.ones([3, 3])

  return x

  # 两个分支必须都定义x, 否则会抛出异常

  with assert_raises(ValueError):

  f()

  Caught expected exception

  : in converted code:

  :3 f *

  if tf.constant(True):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:439 if_stmt

  return tf_if_stmt(cond, body, orelse, get_state, set_state)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:456 tf_if_stmt

  outputs, final_state = control_flow_ops.cond(cond, body, orelse)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507 new_func

  return func(*args, **kwargs)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:1147 cond

  return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:86 cond_v2

  op_return_value=pred)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:716 func_graph_from_py_func

  func_outputs = python_func(*func_args, **func_kwargs)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:486 wrapper

  outputs = func()

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:512 wrapper

  tuple(s.symbol_name for s in undefined)))

  ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.

  AutoGraph和循环

  AutoGraph有一些简单的转换循环规则。

  for:如果iterable是张量,则转换

  while:如果while条件取决于张量,则转换

  如果循环被转换,它将被动态展开tf.while_loop,或者在a的特殊情况下for x in tf.data.Dataset转换为tf.data.Dataset.reduce。

  如果未转换循环,则将静态展开

  # 测试

  def test_dynamically_unrolled(f, *args):

  g = f.get_concrete_function(*args).graph

  if any(node.name == 'while' for node in g.as_graph_def().node):

  print("{}({}) uses tf.while_loop.".format(

  f.__name__, ', '.join(map(str, args))))

  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):

  print("{}({}) uses tf.data.Dataset.reduce.".format(

  f.__name__, ', '.join(map(str, args))))

  else:

  print("{}({}) gets unrolled.".format(

  f.__name__, ', '.join(map(str, args))))

  @tf.function

  def for_in_range():

  x = 0

  for i in range(5):

  x += i

  return x

  @tf.function

  def for_in_tfrange():

  x = tf.constant(0, dtype=tf.int32)

  for i in tf.range(5): # 生成迭代的张量

  x += i

  return x

  @tf.function

  def for_in_tfdataset():

  x = tf.constant(0, dtype=tf.int64)

  for i in tf.data.Dataset.range(5):

  x += i

  return x

  test_dynamically_unrolled(for_in_range)

  test_dynamically_unrolled(for_in_tfrange)

  test_dynamically_unrolled(for_in_tfdataset)

  for_in_range() gets unrolled.

  for_in_tfrange() uses tf.while_loop.

  for_in_tfdataset() uses tf.data.Dataset.reduce.

  @tf.function

  def while_py_cond():

  x = 5

  while x > 0:

  x -= 1

  return x

  @tf.function

  def while_tf_cond():

  x = tf.constant(5)

  while x > 0: # while中的x为张量

  x -= 1

  return x

  test_dynamically_unrolled(while_py_cond)

  test_dynamically_unrolled(while_tf_cond)

  while_py_cond() gets unrolled.

  while_tf_cond() uses tf.while_loop.

  如果有一个break或早期的return子句依赖于张量,那么顶级条件或者iterable也应该是一个张量。

  @tf.function

  def buggy_while_py_true_tf_break(x):

  while True:

  if tf.equal(x, 0):

  break

  x -= 1

  return x

  @tf.function

  def while_tf_true_tf_break(x):

  while tf.constant(True): # 有break,顶级条件必须为张量

  if tf.equal(x, 0):

  break

  x -= 1

  return x

  with assert_raises(TypeError):

  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)

  test_dynamically_unrolled(while_tf_true_tf_break, 5)

  Caught expected exception

  : in converted code:

  :3 buggy_while_py_true_tf_break *

  while True:无锡人流医院哪家好 http://www.ytsg029.com/

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:313 while_stmt

  return _py_while_stmt(test, body, init_state, opts)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:401 _py_while_stmt

  while test(*state):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__

  raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "

  TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

  while_tf_true_tf_break(5) uses tf.while_loop.

  @tf.function

  def buggy_py_for_tf_break():

  x = 0

  for i in range(5):

  if tf.equal(i, 3):

  break

  x += i

  return x

  @tf.function

  def tf_for_tf_break():

  x = 0

  for i in tf.range(5): # 有break,顶级迭代器必须为张量

  if tf.equal(i, 3):

  break

  x += i

  return x

  with assert_raises(TypeError):

  test_dynamically_unrolled(buggy_py_for_tf_break)

  test_dynamically_unrolled(tf_for_tf_break)

  Caught expected exception

  : in converted code:

  :4 buggy_py_for_tf_break *

  for i in range(5):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:110 for_stmt

  return _py_for_stmt(iter_, extra_test, body, init_state)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:117 _py_for_stmt

  if extra_test is not None and not extra_test(*state):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__

  raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "

  TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

  tf_for_tf_break() uses tf.while_loop.

  为了累积动态展开循环的结果,需要使用tf.TensorArray。

  # 实现一个动态rnn

  batch_size = 32

  seq_len = 3

  feature_size=4

  # rnn步,输入与状态叠加

  def rnn_step(inputs, state):

  return inputs + state

  @tf.function

  def dynamic_rnn(rnn_step, input_data, initial_state):

  # [batch, time, features] -> [time, batch, features]

  input_data = tf.transpose(input_data, [1, 0, 2]) # 每个时间维度,都是整个batch数据喂入

  max_seq_len = input_data.shape[0]

  # 保存循环中的状态,必须使用tf.TensorArray

  states = tf.TensorArray(tf.float32, size=max_seq_len)

  state = initial_state

  # 迭代时间步

  for i in tf.range(max_seq_len):

  state = rnn_step(input_data[i], state)

  states = states.write(i, state)

  # 把 batch_size重新换到前面

  return tf.transpose(states.stack(), [1, 0, 2])

  dynamic_rnn(rnn_step,

  tf.random.uniform([batch_size, seq_len, feature_size]),

  tf.zeros([batch_size, feature_size]))

  array([[[0.42647886, 0.73600817, 0.10211909, 0.89989746],

  [0.772506 , 1.6853498 , 0.48793948, 1.4499462 ],

  [1.1096102 , 2.3388233 , 0.5920907 , 1.588302 ]],

  ...

  [[0.15579033, 0.4594922 , 0.17970431, 0.19183934],

  [0.19597077, 0.5362154 , 0.19988954, 0.38290274],

  [0.7524748 , 1.0519221 , 0.76595306, 0.5257962 ]]], dtype=float32)>

  与此同时tf.cond,tf.while_loop还带有一些细微之处。 - 由于循环可以执行0次,因此必须在循环上方初始化在while_loop下游使用的所有张量 - 所有循环变量的形状/ dtypes必须与每次迭代保持一致

  @tf.function

  def buggy_loop_var_uninitialized():

  for i in tf.range(3):

  x = i # 必须在循环上方初始化好x

  return x

  @tf.function

  def f():

  x = tf.constant(0)

  for i in tf.range(3):

  x = i

  return x

  with assert_raises(ValueError):

  buggy_loop_var_uninitialized()

  f()

  Caught expected exception

  : in converted code:

  :3 buggy_loop_var_uninitialized *

  for i in tf.range(3):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt

  return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:125 _known_len_tf_for_stmt

  _disallow_undefs_into_loop(*init_state)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:50 _disallow_undefs_into_loop

  tuple(s.symbol_name for s in undefined)))

  ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)

  循环时 变量的类型不能改变

  @tf.function

  def buggy_loop_type_changes():

  x = tf.constant(0, dtype=tf.float32)

  for i in tf.range(3): # Yields tensors of type tf.int32...

  x = i

  return x

  with assert_raises(tf.errors.InvalidArgumentError):

  buggy_loop_type_changes()

  Caught expected exception

  : Input 1 of node while/merge/_10 was passed int32 from while/next_iteration/_28:0 incompatible with expected float. [Op:__inference_buggy_loop_type_changes_2119]

  循环时变量形状也不能改变

  @tf.function

  def buggy_concat():

  x = tf.ones([0, 10])

  for i in tf.range(5):

  x = tf.concat([x, tf.ones([1, 10])], axis=0) # 循环时变量形状不能改变

  return x

  with assert_raises(ValueError):

  buggy_concat()

  @tf.function

  def concat_with_padding():

  x = tf.zeros([5, 10])

  for i in tf.range(5):

  x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)

  x.set_shape([5, 10])

  return x

  concat_with_padding()

  Caught expected exception

  : in converted code:

  :4 buggy_concat *

  for i in tf.range(5):

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt

  return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:156 _known_len_tf_for_stmt

  opts=dict(maximum_iterations=n))

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:327 _tf_while_stmt

  retval = control_flow_ops.while_loop(test, body, init_state, **opts)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2646 while_loop

  return_same_structure=return_same_structure)

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:213 while_loop

  len_orig_loop_vars], expand_composites=True))

  /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:869 _check_shapes_compat

  "specify a less-specific shape." % (input_t.name, shape, t.shape))

  ValueError: Input tensor 'ones:0' enters the loop with shape (0, 10), but has shape (1, 10) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.

  array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],

  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],

  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],

  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],

  [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

猜你喜欢

转载自www.cnblogs.com/gnz49/p/11592195.html
今日推荐