tf.control_dependencies() and tf.while_loop()

tf.control_dependencies()设计是用来控制计算流图的,给图中的某些计算指定顺序。比如:我们想要获取参数更新后的值,那么我们可以这么组织我们的代码。


opt = tf.train.Optimizer().minize(loss)

with tf.control_dependencies([opt]):
  updated_weight = tf.identity(weight)

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  sess.run(updated_weight, feed_dict={...}) # 这样每次得到的都是更新后的weight

关于tf.control_dependencies的具体用法,請移步官网https://www.tensorflow.org/api_docs/python/tf/Graph#control_dependencies,总结一句话就是,在执行某些op,tensor之前,某些op,tensor得首先被运行。

下面说明两种 control_dependencies 不 work 的情况

下面有两种情况,control_dependencies不work,其实并不是它真的不work,而是我们的使用方法有问题。

第一种情况:

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    ema_val = ema.average(update)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val]))

也许你会觉得,在我们 sess.run([ema_val]), ema_op 都会被先执行,然后再计算ema_val,实际情况并不是这样,为什么? 
有兴趣的可以看一下源码,就会发现 ema.average(update) 不是一个 op,它只是从ema对象的一个字典中取出键对应的tensor 而已,然后赋值给ema_val。这个 tensor是由一个在 tf.control_dependencies([ema_op]) 外部的一个 op 计算得来的,所以 control_dependencies会失效。解决方法也很简单,看代码:

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    ema_val = tf.identity(ema.average(update)) #一个identity搞定

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val]))

第二种情况: 这个情况一般不会碰到,这是我在测试 control_dependencies 发现的

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    w1 = tf.Variable(2.0)
    ema_val = ema.average(update)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val, w1]))

这种情况下,control_dependencies也不 work。读取 w1 的值并不会触发 ema_op, 原因请看代码:

#这段代码出现在Variable类定义文件中第287行,
# 在创建Varible时,tensorflow是移除了dependencies了的
#所以会出现 control 不住的情况
with ops.control_dependencies(None):
    ...      
  • 明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/u012436149/article/details/72084744
  • tf.while_loop 可以这样理解
  • loop = []
    while cond(loop):
        loop = body(loop)
  • 即loop参数先传入cond 判断条件是否成立,成立之后,把 loop参数传入body 执行操作, 然后返回 操作后的 loop 参数,即loop参数已被更新,再把更新后的参数传入cond, 依次循环,直到不满足条件。
  • 我们来看这样一个场景如何在 tensorflow中实现

    i= 0
    n = maxiteration
    while(i < n):
        i = i +1
  •  
  • 首先这个要有个判断条件的语句 即

    i  < n
  •  
  • 满足条件就执行循环体里的操作,这个判断条件在tensorflow里,要写个函数来代替即

    def cond(i, n):
        return i < n
  •  
  • 之后是循环体里的操作,也要一个函数来代替即

    def body(i, n):
        i = i + 1
        return i, n
  •  
  • 请注意body函数里虽然没有与参数 n 有关的操作,但是必须要传入参数 n, 因为正如前面所说,要构成循环,参数在body函数更新后还要返回给cond函数,判断是否满足条件,如果不传入参数 n 下次,就没法判断了。

    合起来总得代码为

    i  = 0
    n = maxiteration 
    
    def cond(i, n):
        return i < n
    
    def body(i, n):
        i = i + 1
        return i, n
    i, n = tf.while_loop(cond, body, [i, n])
  • 代码
  • import tensorflow as tf 
    i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
    n = tf.constant(40)
    
    def cond(i, n):
        return  i < n
    def body(a, n):
        a = i + 1
        return a, n
    
    i, n = tf.while_loop(cond, body, [i, n])
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        res = sess.run([i, n])
        print(res)

猜你喜欢

转载自blog.csdn.net/qq_34504481/article/details/81908137
tf
今日推荐