关于tf.cond分支执行与否的问题

虽然TensorFlow的Graph表面也有流程控制,但和普通程序的if还是差别挺大的,比如下面的代码: 

import tensorflow as tf
X1 = tf.Variable(0.)
cond_value = tf.Variable(False)
assign_1 = tf.assign(X1, 1.)
assign_2 = tf.assign(X1, 2.)
cond_result = tf.cond(cond_value, lambda: assign_1, lambda: assign_2)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(cond_result))

通过一个tf.cond来决定是个X1赋值为1还是2。因为cond_value是False,所以期望X1的值是2。但实际运行结果却是1。

原因吗?要记住TensorFlow决定是否执行某个OP是根据依赖关系。

这个例子说明cond_value是依赖于assign_1,  assign_2的,所以不管cond_value的值是多少,assign_1,  assign_2都会被执行,只是cond_value会决定最终使用哪一个的结果。这就带来的两个问题:

  1. 使用tf.assign的时候,因为并不关注返回的值(返回值都是X1的地址,不管走哪个分支,返回值都是一样的),所以这个条件流程失效。
  2. 当我们想使用条件判断来规避一些异常值得时候,比如当索引tensor的index超出tensor的范围的时候,就走一个异常分支。但事实是因为不管cond_value是啥,两个分支都会被执行,所以还是会报错。
import tensorflow as tf
data=tf.constant([1])
cond_value = tf.Variable(False)
assign_1 = data[1]
assign_2 = data[0]
cond_result = tf.cond(cond_value, lambda: assign_1, lambda: assign_2)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(cond_result))

这样写就会报索引值超出范围的错误。即使cond_value为False,应该只执行后面那个分支。

不过这个问题也有绕开的方法,就是把tf.assign封装到独立的函数里面。

import tensorflow as tf
X1 = tf.Variable(0.)
cond_value = tf.Variable(False)
def func1():
    return tf.assign(X1, 1.)
def func2():
    return tf.assign(X1, 2.)
cond_result = tf.cond(cond_value, func1, func2)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(cond_result))

但可悲的是这个方法只能解决上面提到的第一个问题,第二个超出范围的问题任然存在T T。笔者有做了很多实验,发现这里面的水还相当的深,所以还是建议一般情况还是避开这种使用吧。

猜你喜欢

转载自blog.csdn.net/ziliwangmoe/article/details/81360792
今日推荐