TensorFlow Deep Learning, Part 28: tf.cond

Copyright notice: this is an original post by the author and may not be reproduced without the author's permission. https://blog.csdn.net/DaVinciL/article/details/82496407

1. Introduction

def cond(pred,            # the predicate, i.e. the condition to branch on
         true_fn=None,    # callable whose result is returned when the predicate is True
         false_fn=None,   # callable whose result is returned when the predicate is False
         strict=False,    # enables/disables 'strict' mode (see below)
         name=None,       # optional name prefix for the returned tensors
         fn1=None,        # legacy alias for true_fn (deprecated)
         fn2=None):       # legacy alias for false_fn (deprecated)

API docstring:
Return true_fn() if the predicate pred is true else false_fn().

true_fn and false_fn both return lists of output tensors. true_fn and false_fn must have the same non-zero number and type of outputs.

Note that the conditional execution applies only to the operations defined in true_fn and false_fn. Consider the following simple program:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics.
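
If the multiply should only run when the true branch is taken, the op can be created inside the branch function instead. The following is a minimal sketch of that pattern (the concrete constant values are mine, chosen just to make it runnable):

import tensorflow as tf

# Example values only; any tensors of compatible types would do.
a = tf.constant(3.0)
b = tf.constant(4.0)
x = tf.constant(1.0)
y = tf.constant(2.0)

# Because the multiply is now created inside the first branch function,
# it is only executed when the x < y branch is actually taken.
result = tf.cond(x < y,
                 lambda: tf.add(x, tf.multiply(a, b)),
                 lambda: tf.square(y))

with tf.Session() as sess:
    print(sess.run(result))  # 1.0 < 2.0, so this prints 1.0 + 3.0 * 4.0 = 13.0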

Note that cond calls true_fn and false_fn exactly once (inside the call to cond, and not at all during Session.run()). cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred.
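
One way to see this, as a small sketch of my own: put a plain Python print inside each branch function. Both messages appear exactly once, while the graph is being built, and nothing extra is printed during Session.run():

import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)


def true_fn():
    print("true_fn called (graph construction)")
    return tf.multiply(x, 17)


def false_fn():
    print("false_fn called (graph construction)")
    return tf.add(y, 23)


# Both Python functions are called here, exactly once each, to build the
# two graph fragments that tf.cond stitches together.
r = tf.cond(tf.less(x, y), true_fn, false_fn)

with tf.Session() as sess:
    # Only the ops of the selected branch run here; the Python functions
    # are not called again.
    print(sess.run(r))  # 2 < 5 is True, so the multiply branch gives 34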

tf.cond supports nested structures as implemented in tensorflow.python.util.nest. Both true_fn and false_fn must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by true_fn and/or false_fn, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True.
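
A brief illustration of these rules (my own sketch, not part of the original docstring): both branches return a matching tuple of two tensors, and a singleton list is implicitly unpacked to a single tensor under the default strict=False:

import tensorflow as tf

pred = tf.constant(True)

# Both branches return the same nested structure: a tuple of two tensors.
r1, r2 = tf.cond(pred,
                 lambda: (tf.constant(1), tf.constant(10)),
                 lambda: (tf.constant(2), tf.constant(20)))

# With the default strict=False, a singleton list is unpacked to a single tensor;
# passing strict=True would keep the one-element list structure instead.
single = tf.cond(pred,
                 lambda: [tf.constant(1)],
                 lambda: [tf.constant(2)])

with tf.Session() as sess:
    print(sess.run([r1, r2]))  # [1, 10]
    print(sess.run(single))    # 1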

Summary: this function works like an if...else branch: when the predicate evaluates to true, the first callable is invoked, and when it evaluates to false, the second one is. This is very useful in practice because TensorFlow builds the graph before any data is available, so an ordinary Python if cannot branch on values that are only known at run time; tf.cond provides an interface for expressing such decisions in the graph itself (the placeholder example in Section 3 shows this).

2. Parameters
In practice, we usually only need the following parameters.

Parameter   Description
pred        A scalar determining whether to return the result of true_fn or false_fn, i.e. the condition to branch on.
true_fn     The callable to be performed if pred is true.
false_fn    The callable to be performed if pred is false.
strict      A boolean that enables/disables 'strict' mode; see above.
name        Optional name prefix for the returned tensors.
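
For completeness, here is a small sketch of my own showing the name argument, which simply prefixes the names of the tensors that tf.cond returns:

import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

r = tf.cond(tf.less(x, y),
            lambda: tf.multiply(x, 17),
            lambda: tf.add(y, 23),
            name="choose_branch")

# The returned tensor's name carries the "choose_branch" prefix
# (in TF 1.x it typically looks like "choose_branch/Merge:0").
print(r.name)

with tf.Session() as sess:
    print(sess.run(r))  # 2 < 5 is True, so 2 * 17 = 34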

3. Code

import tensorflow as tf
import numpy as np

x = tf.constant(2)
y = tf.constant(1)


def f1(): return tf.multiply(x, 17)


def f2(): return tf.add(y, 23)


# 2 < 1 evaluates to False, so the f2 branch is selected.
r = tf.cond(tf.less(x, y), f1, f2)

with tf.Session() as sess:
    print(sess.run(r))

Output: 2 < 1 is False, so f2 is executed and the result is 1 + 23 = 24.

24

import tensorflow as tf
import numpy as np

x = tf.constant(2)
y = tf.constant(5)  # the only difference from the previous program is the value of y


def f1(): return tf.multiply(x, 17)


def f2(): return tf.add(y, 23)


# 2 < 5 evaluates to True, so the f1 branch is selected.
r = tf.cond(tf.less(x, y), f1, f2)

with tf.Session() as sess:
    print(sess.run(r))

Output: 2 < 5 is True, so f1 is executed and the result is 2 * 17 = 34.

34

For convenience, the branch functions can also be written as lambdas.

# coding=utf-8
import tensorflow as tf
import numpy as np

a = tf.placeholder(dtype=tf.float32)

# define some arbitrary computation on the placeholder
b = tf.add(a, 32)

c = tf.add(a, 56)

# a is only known at run time (via feed_dict), so the branch is chosen when the graph runs.
res = tf.cond(a < 10, lambda: b + 10, lambda: c * 2)

with tf.Session() as sess:
    print(sess.run(res, feed_dict={a: 13}))

Result:

138.0
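
Feeding a value below 10 instead, for example a: 5, would make the predicate a < 10 true, so the b + 10 branch would run and print 47.0, i.e. (5 + 32) + 10.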
