tf.control_dependencies()用法

示例程序1:

观察下面程序可以发现,每更新一次update_a操作,就会更新a和b的值

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    print(sess.run(update_a))  # 输出4
    print(sess.run(b))  # 输出7,因为更新a后,得到b的值需要重新加3,因此b=4+3=7
    print(sess.run(a))  # 输出4
    print(sess.run(b))  # 输出7

    print(sess.run(update_a))  # 输出7
    print(sess.run(b))  # 输出10,因为更新a后,得到b的值需要重新加3,因此b=7+3=10
    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10

示例程序2:

观察下面程序可以发现,此时tf.control_dependencies([update_a])并未执行update_a操作。因此a和b的值均未改变。

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    with tf.control_dependencies([update_a]):
        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

示例程序3:

观察下面程序可以发现,此时执行以下两条语句:

        e = tf.identity(b) + 5
        val = tf.identity(a)

。执行前后的a和b的值均未改变,说明这两条语句未执行update_a操作。

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    with tf.control_dependencies([update_a]):
        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

        e = tf.identity(b) + 5
        val = tf.identity(a)

        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

示例程序4:

观察下面程序可以发现, 执行print(sess.run(e)) 时,先输出9后,再执行update_a操作,update_a=a=b=4,b=a+3=7

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    with tf.control_dependencies([update_a]):
        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

        e = tf.identity(b) + 5

        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

    print(sess.run(e))  # 输出9
    # 输出9后,再执行update_a操作,update_a=a=b=4,b=a+3=7
    print(sess.run(a))  # 输出4
    print(sess.run(b))  # 输出7

    print(sess.run(e))  # 输出12
    # 输出12后,再执行update_a操作,update_a=a=b=7,b=a+3=10
    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10

示例程序5:

观察下面程序可以发现,执行print(sess.run(val))时,在输出4前,先执行update_a操作,update_a=a=b=4,b=a+3=7,然后val=a=4

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    with tf.control_dependencies([update_a]):
        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

        val = tf.identity(a)

        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

    # 输出4前,先执行update_a操作,update_a=a=b=4,b=a+3=7,然后val=a=4
    print(sess.run(val))  # 输出4
    print(sess.run(a))  # 输出4
    print(sess.run(b))  # 输出7

    # 输出7前,先执行update_a操作,update_a=a=b=7,b=a+3=10,然后val=a=7
    print(sess.run(val))  # 输出7
    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10

 示例程序5:

根据以上分析,可以轻松得到以下程序的输出结果

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3

update_a = tf.assign(a, b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    with tf.control_dependencies([update_a]):
        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

        e = tf.identity(b) + 5
        val = tf.identity(a)

        print(sess.run(a))  # 输出1
        print(sess.run(b))  # 输出4

    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    print(sess.run(e))  # 输出9,输出9后,再执行update_a操作,update_a=a=b=4,b=a+3=7
    print(sess.run(a))  # 输出4
    print(sess.run(b))  # 输出7

    # 此时,先执行了update_a操作,update_a=a=b=7,b=a+3=10,然后val = a = 7
    print(sess.run(val))  # 输出7
    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10

  示例程序6:通过以下程序分析,print(sess.run([e, val])),意为同时并行执行,互不干扰,且均会执行update_a操作,因此,他们并行输出的结果一致,即产生相同的a和b

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3
update_a = tf.assign(a, b)

with tf.control_dependencies([update_a]):
    e = tf.identity(b) + 5
    val = tf.identity(a)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4
    print(sess.run([e, val]))
    print(sess.run(a))  # 输出4
    print(sess.run(b))  # 输出7

 示例程序7:通过以下程序分析,先执行print(sess.run(e)) 操作,然后执行print(sess.run(val))操作,即后者的操作是在前者的基础上进行的

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3
update_a = tf.assign(a, b)

with tf.control_dependencies([update_a]):
    e = tf.identity(b) + 5
    val = tf.identity(a)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4
    print(sess.run(e))  # 输出9
    print(sess.run(val))  # 输出7
    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10



示例程序7:通过以下程序分析,先执行print(sess.run(val))操作,然后执行print(sess.run(e))操作,即后者的操作是在前者的基础上进行的 

import tensorflow as tf

a = tf.Variable(initial_value=[1.], dtype=tf.float32)
b = a + 3
update_a = tf.assign(a, b)

with tf.control_dependencies([update_a]):
    e = tf.identity(b) + 5
    val = tf.identity(a)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # 输出1
    print(sess.run(b))  # 输出4

    print(sess.run(val))  # 输出4
    print(sess.run(e))  # 输出12

    print(sess.run(a))  # 输出7
    print(sess.run(b))  # 输出10



猜你喜欢

转载自blog.csdn.net/qq_36201400/article/details/108146509
今日推荐