【tensorflow】tf.cond() 控制数据流向

最近使用TensorFlow时,需要对数据(x, y)输入CNN网络的具体分支进行控制,有时x输入第1分支,y输入第2分支,有时需要反过来。起初,设置了一个placeholder,直接根据它的大小进行if…else的逻辑判断,发现生成的计算图只包含一种输入情况,网上查了后,才发现TensorFlow有这个功能,那就是使用tf.cond() 函数控制数据流向

tf.cond() 函数的说明文档如下:

format:tf.cond(pred, fn1, fn2, name=None)
Return :either fn1() or fn2() based on the boolean predicate `pred`.
(注意:'fn1'和‘fn2’是两个函数)

当pred为True时,执行fn1;反之,执行fn2

下面是一个具体的例子:


import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
x = tf.constant(4)
y = tf.constant(5)
A = tf.cond(x < y, lambda: tf.identity(a), lambda: tf.identity(b))
B = tf.add(A, a)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
summary_writer = tf.summary.FileWriter("logs", sess.graph)

with tf.Session() as session:
    print(B.eval())

# prints "4"

在tf.cond()函数中,因为 x < y 成立,所以执行lambda: tf.identity(a),其中,lambda这是Python支持一种语法,它允许你快速定义单行的最小函数,类似与C语言中的宏,这些叫做lambda的函数,是从LISP借用来的,可以用在任何需要函数的地方。tf.identity()函数返回的是与a相同大小相同内容的tensor.

所以最后A=a=2,而B=A+a=4.

计算图在tensorboard的展示效果为:
这里写图片描述


Reference
[1] 使用if..else..出现的错误
[2] tf.cond()的用法, CSDN博客
[3] tensorflow python lambda, CSDN博客
[4] C语言的艺术:强大的宏, CSDN博客

猜你喜欢

转载自blog.csdn.net/yideqianfenzhiyi/article/details/79406122