30天干掉tensorflow2.0-day13 Autograph的使用规范

AutoGraph的使用规范

有三种计算图的构建方式:静态计算图,动态计算图,以及Autograph。

TensorFlow 2.0主要使用的是动态计算图和Autograph。

动态计算图易于调试,编码效率较高,但执行效率偏低。

静态计算图执行效率很高,但较难调试。

而Autograph机制可以将动态图转换成静态计算图,兼收执行效率和编码效率之利。

当然Autograph机制能够转换的代码并不是没有任何约束的,有一些编码规范需要遵循,否则可能会转换失败或者不符合预期。

我们将着重介绍Autograph的编码规范和Autograph转换成静态图的原理。

并介绍使用tf.Module来更好地构建Autograph。

本篇我们介绍使用Autograph的编码规范。

一,Autograph编码规范总结

  • 1,被@tf.function修饰的函数应尽可能使用TensorFlow中的函数而不是Python中的其他函数。例如使用tf.print而不是print,使用tf.range而不是range,使用tf.constant(True)而不是True.

  • 2,避免在@tf.function修饰的函数内部定义tf.Variable.

  • 3,被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等数据结构变量。

二,Autograph编码规范解析

1,被@tf.function修饰的函数应尽量使用TensorFlow中的函数而不是Python中的其他函数。

import numpy as np
import tensorflow as tf

@tf.function
def np_random():
    a = np.random.randn(3,3)
    tf.print(a)
    
@tf.function
def tf_random():
    a = tf.random.normal((3,3))
    tf.print(a)
#np_random每次执行都是一样的结果。
np_random()
np_random()
array([[-0.71624973,  1.29642527, -0.65452842],
       [-0.7035557 , -1.03091348,  0.11619214],
       [-0.32508337,  0.65632219, -1.27583452]])
array([[-0.71624973,  1.29642527, -0.65452842],
       [-0.7035557 , -1.03091348,  0.11619214],
       [-0.32508337,  0.65632219, -1.27583452]])
#tf_random每次执行都会有重新生成随机数。
tf_random()
tf_random()
[[-1.76797986 0.683240771 0.914823711]
 [1.38333535 1.08455276 0.00554183451]
 [0.875963 -0.579497099 -1.04468513]]
[[-0.738048911 1.65886068 0.181300119]
 [-1.24241471 -0.0364666954 -0.514938533]
 [-1.14512217 1.35526192 -0.667849422]]

2,避免在@tf.function修饰的函数内部定义tf.Variable.

# 避免在@tf.function修饰的函数内部定义tf.Variable.

x = tf.Variable(1.0,dtype=tf.float32)
@tf.function
def outer_var():
    x.assign_add(1.0)
    tf.print(x)
    return(x)

outer_var() 
outer_var()
2
3





<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
@tf.function
def inner_var():
    x = tf.Variable(1.0,dtype = tf.float32)
    x.assign_add(1.0)
    tf.print(x)
    return(x)

#执行将报错
inner_var()
inner_var()
WARNING:tensorflow:From D:\anaconda3\lib\site-packages\tensorflow_core\python\ops\resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.



---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-6-c95a7c3c1ddd> in <module>
      7 
      8 #执行将报错
----> 9 inner_var()
     10 inner_var()


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py in _call(self, *args, **kwds)
    630         # Lifting succeeded, so variables are initialized and we can run the
    631         # stateless function.
--> 632         return self._stateless_fn(*args, **kwds)
    633     else:
    634       canon_args, canon_kwds = \


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\function.py in __call__(self, *args, **kwargs)
   2360     """Calls a graph function specialized to the inputs."""
   2361     with self._lock:
-> 2362       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2363     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2364 


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of


D:\anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,


D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 


D:\anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise


ValueError: in converted code:

    <ipython-input-6-c95a7c3c1ddd>:3 inner_var  *
        x = tf.Variable(1.0,dtype = tf.float32)
    D:\anaconda3\lib\site-packages\tensorflow_core\python\ops\variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    D:\anaconda3\lib\site-packages\tensorflow_core\python\ops\variables.py:254 _variable_v2_call
        shape=shape)
    D:\anaconda3\lib\site-packages\tensorflow_core\python\ops\variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    D:\anaconda3\lib\site-packages\tensorflow_core\python\eager\def_function.py:502 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

3,被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等结构类型变量。

tensor_list = []

#@tf.function #加上这一行切换成Autograph结果将不符合预期!!!
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list

append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

[<tf.Tensor: shape=(), dtype=float32, numpy=5.0>, <tf.Tensor: shape=(), dtype=float32, numpy=6.0>]
tensor_list = []

@tf.function #加上这一行切换成Autograph结果将不符合预期!!!
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list


append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

[<tf.Tensor 'x:0' shape=() dtype=float32>]

原创文章 58 获赞 7 访问量 6202

猜你喜欢

转载自blog.csdn.net/Elenstone/article/details/105513857