【译】Effective TensorFlow Chapter8——控制流操作:条件和循环

本文翻译自: 《Control flow operations: conditionals and loops》, 如有侵权请联系删除,仅限于学术交流,请勿商用。如有谬误,请联系指出。

当我们在构建一个复杂模型如RNN(循环神经网络)的时候,你可能需要通过条件和循环来控制操作流程。在这一节,我们介绍一些在TensorFlow中常用的控制流。

假设我们现在需要通过一个条件判断来决定我们是否相加还是相乘两个变量。这个可以通过调用tf.cond()简单实现,它表现出像python中if...else...相似的功能。

a = tf.constant(1)
b = tf.constant(2)

p = tf.constant(True)

x = tf.cond(p, lambda: a + b, lambda: a * b)

print(tf.Session().run(x))
复制代码

因为这个条件判断为True,所以这个输出应该是加法输出,也就是输出3。

在使用TensorFlow的过程中,大部分时间你都会使用大型的张量,并且在一个批次(a batch)中进行操作。一个与之相关的条件操作符是tf.where(),它需要提供一个条件判断,就和tf.cond()一样,但是tf.where()将会根据这个条件判断,在一个批次中选择输出,如:

a = tf.constant([1, 1])
b = tf.constant([2, 2])

p = tf.constant([True, False])

x = tf.where(p, a + b, a * b)

print(tf.Session().run(x))
复制代码

返回的结果是[3, 2]

另一个广泛使用的控制流操作是tf.while_loop()。它允许在TensorFlow中构建动态的循环,这个可以实现对一个序列的变量进行操作。让我们看看我们如何通过tf.while_loops函数生成一个斐波那契数列吧:

n = tf.constant(5)

def cond(i, a, b):
    return i < n

def body(i, a, b):
    return i + 1, b, a + b

i, a, b = tf.while_loop(cond, body, (2, 1, 1))

print(tf.Session().run(b))
复制代码

这段代码将会返回5。tf.while_loop()需要一个条件函数和一个循环体函数,除此之外还需初始化循环变量。这些循环变量在每一次循环体函数调用完之后都会被更新一次,直到这个条件返回False为止。

现在想象我们想要保存这个斐波那契序列,我们可能更新我们的循环体函数以纪录当前值的历史记录:

n = tf.constant(5)

def cond(i, a, b, c):
    return i < n

def body(i, a, b, c):
    return i + 1, b, a + b, tf.concat([c, [a + b]], 0)

i, a, b, c = tf.while_loop(cond, body, (2, 1, 1, tf.constant([1, 1])))

print(tf.Session().run(c))
复制代码

当你尝试运行这个程序的时候,TensorFlow将会“抱怨”说第四个循环变量的形状正在改变。所以你必须指明这件事后有意为之的:

i, a, b, c = tf.while_loop(
    cond, body, (2, 1, 1, tf.constant([1, 1])),
    shape_invariants=(tf.TensorShape([]),
                      tf.TensorShape([]),
                      tf.TensorShape([]),
                      tf.TensorShape([None])))
复制代码

这使得代码变得丑陋不堪,而且效率极低。请注意,我们正在构建许多我们不使用的中间张量。TensorFlow对这种增长式的数组,其实有一个更好的解决方案:tf.TensorArray。让我们用张量数组做同样的事情:

n = tf.constant(5)

c = tf.TensorArray(tf.int32, n)
c = c.write(0, 1)
c = c.write(1, 1)

def cond(i, a, b, c):
    return i < n

def body(i, a, b, c):
    c = c.write(i, a + b)
    return i + 1, b, a + b, c

i, a, b, c = tf.while_loop(cond, body, (2, 1, 1, c))

c = c.stack()

print(tf.Session().run(c))
复制代码

TensorFlow 中的while_loop和张量数组是构建复杂的循环神经网络(RNN)的基本工具。作为练习,你可以尝试使用tf.while_loops实现 beam search 。你可以再尝试使用张量数组提高效率吗?

猜你喜欢

转载自juejin.im/post/5c73e15f518825628c30f1de
今日推荐