Keras(九) tf.function函数转换、@tf.function函数转换

本文将介绍如下内容:

  • 使用tf.function和AutoGraph提高代码性能
  • 使用@tf.function进行函数转换
  • 展示tf.function转换后的代码
  • tf.Variable无法在函数内部定义

一,使用tf.function和AutoGraph提高代码性能

1,为方便测试,自定义激活函数
# 使用TF2的内置方法tf.function(),自定义scaled_elu激活函数
# tf.function and auto-graph.
def scaled_elu(z, scale=1.0, alpha=1.0):
    # z >= 0 ? scale * z : scale * alpha * tf.nn.elu(z)
    is_positive = tf.greater_equal(z, 0.0)
    return scale * tf.where(is_positive, z, alpha * tf.nn.elu(z))

print(scaled_elu(tf.constant(-3.)))
print(scaled_elu(tf.constant([-3., -2.5])))

# ----output----------
tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)
2,使用tf.function函数将python函数转化为tf函数
scaled_elu_tf = tf.function(scaled_elu)
print(scaled_elu_tf(tf.constant(-3.)))
print(scaled_elu_tf(tf.constant([-3., -2.5])))

#---output------
tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)
3,返回已经转化为tf函数的python原函数
print(scaled_elu_tf.python_function is scaled_elu)

#---output------
True
4,对比python函数和转换TF后的函数的性能
%timeit scaled_elu(tf.random.normal((1000, 1000)))
%timeit scaled_elu_tf(tf.random.normal((1000, 1000)))

#---output------
745 µs ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
483 µs ± 39.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

二,使用@tf.function进行函数转换

tf.function函数类似,也可以使用装饰器@tf.function进行函数转换

# 1 + 1/2 + 1/2^2 + ... + 1/2^n

@tf.function
def converge_to_2(n_iters):
    total = tf.constant(0.)
    increment = tf.constant(1.)
    for _ in range(n_iters):
        total += increment
        increment /= 2.0
    return total

print(converge_to_2(20))

#---output------
tf.Tensor(1.9999981, shape=(), dtype=float32)

三,展示tf.function转换后的代码

# 7,展示tf.function转换后的代码
def display_tf_code(func):
    code = tf.autograph.to_code(func)
    from IPython.display import display, Markdown
    display(Markdown('```python\n{}\n```'.format(code)))
display_tf_code(converge_to_2)

#---output------
<IPython.core.display.Markdown object>

四,tf.Variable无法在TF函数内部定义

tf.Variable是个稍微特殊的操作,因为没有办法确定构建图的时候函数调用了多少次,而tf.Variable只会被创建一次,这就有了冲突,所以tf.Variable在有@tf.function的时候只能放到外面。
去掉@tf.function,把Variable放到函数中,那么返回值应该是一样的。

import tensorflow as tf 

var = tf.Variable(0.)

@tf.function
def add_21():
    return var.assign_add(21) # += 

print(add_21())

#---output--------
tf.Tensor(21.0, shape=(), dtype=float32)

五,注意

  • tf.function不影响输出类型。
  • display_tf_code()中传入的函数如果带有@tf.function标注,则会报错。to_code函数的输入是module, class, method, function, traceback, frame, or code object。不能是tf function。
  • 用了tf.function的标注就只能使用tensorflow的操作。因为它会把整个函数优化成tensorflow的图,对于其他操作它无法做出优化。

猜你喜欢

转载自blog.csdn.net/TFATS/article/details/110531576