最近使用TensorFlow时,需要对数据(x, y)输入CNN网络的具体分支进行控制,有时x输入第1分支,y输入第2分支,有时需要反过来。起初,设置了一个placeholder,直接根据它的大小进行if…else的逻辑判断,发现生成的计算图只包含一种输入情况,网上查了后,才发现TensorFlow有这个功能,那就是使用tf.cond() 函数控制数据流向。
tf.cond() 函数的说明文档如下:
format:tf.cond(pred, fn1, fn2, name=None)
Return :either fn1() or fn2() based on the boolean predicate `pred`.
(注意:'fn1'和‘fn2’是两个函数)
当pred为True时,执行fn1;反之,执行fn2
下面是一个具体的例子:
import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
x = tf.constant(4)
y = tf.constant(5)
A = tf.cond(x < y, lambda: tf.identity(a), lambda: tf.identity(b))
B = tf.add(A, a)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
summary_writer = tf.summary.FileWriter("logs", sess.graph)
with tf.Session() as session:
print(B.eval())
# prints "4"
在tf.cond()函数中,因为 成立,所以执行lambda: tf.identity(a),其中,lambda这是Python支持一种语法,它允许你快速定义单行的最小函数,类似与C语言中的宏,这些叫做lambda的函数,是从LISP借用来的,可以用在任何需要函数的地方。tf.identity()函数返回的是与a相同大小相同内容的tensor.
所以最后A=a=2,而B=A+a=4.
计算图在tensorboard的展示效果为:
Reference
[1] 使用if..else..出现的错误
[2] tf.cond()的用法, CSDN博客
[3] tensorflow python lambda, CSDN博客
[4] C语言的艺术:强大的宏, CSDN博客