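A Detailed Look at Combining tf.control_dependencies with tf.identity

Introduction

When implementing neural networks we often see tf.control_dependencies in use, but what does this function actually do, and when should we reach for it? Let's find out.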

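Understanding

Read literally, control_dependencies means "control dependencies", so we can guess that the function controls the dependencies between nodes of the computation graph. That is exactly right: tf.control_dependencies() constrains the computation graph by imposing an execution order on certain nodes in it.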

Prototype analysis

tf.control_dependencies(control_inputs)
 arguments: control_inputs: A list of `Operation` or `Tensor` objects
which must be executed or computed before running the operations
defined in the context. (Note that control_inputs is a list.)
return: A context manager that specifies control dependencies
for all operations constructed within the context.

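From this description we know that the function takes control_inputs, a list of Operation or Tensor objects, and returns a context manager that governs the dependencies of the ops constructed inside that context. In other words, ops defined under the context manager depend on the ops in control_inputs: control_dependencies ensures that everything in control_inputs has run before any op defined inside the context runs.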
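To make the rule above concrete, here is a minimal pure-Python sketch of the idea. It is a toy model, not TensorFlow: the names `Op`, `run`, `train_step` and `read_w` are made up for illustration, and the toy scheduler simply runs control inputs before data inputs. The point it shows is that an op picks up its control inputs at the moment it is constructed inside the context.

```python
import contextlib

_scopes = []  # stack of control-input lists that are currently active


class Op:
    """Toy graph node: a callable plus data inputs and control inputs."""

    def __init__(self, name, fn, inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)
        # Control inputs are captured when the op is *constructed*.
        self.control_inputs = [c for scope in _scopes for c in scope]


@contextlib.contextmanager
def control_dependencies(control_inputs):
    """Attach control_inputs to every Op constructed inside the block."""
    _scopes.append(list(control_inputs))
    try:
        yield
    finally:
        _scopes.pop()


def run(op, log, done=None):
    """Evaluate op once per run: control inputs, then data inputs, then op."""
    done = {} if done is None else done
    if op not in done:
        for dep in op.control_inputs:
            run(dep, log, done)
        args = [run(i, log, done) for i in op.inputs]
        done[op] = op.fn(*args)
        log.append(op.name)
    return done[op]


state = {"w": 0}  # hypothetical parameter store


def apply_update():
    state["w"] += 1
    return state["w"]


train_step = Op("train_step", apply_update)
with control_dependencies([train_step]):
    read_w = Op("read_w", lambda: state["w"])  # constructed inside the context

log = []
value = run(read_w, log)
print(log, value)  # ['train_step', 'read_w'] 1
```

Asking only for `read_w` is enough to trigger `train_step` first, because the dependency was recorded when `read_w` was created.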

Example

If we want to make sure that we read the parameters after they have been updated, we can organize the code like this.

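import tensorflow as tf  # TensorFlow 1.x graph-mode API

# loss and weight are assumed to be defined earlier in the graph.
# tf.train.Optimizer is a base class that is not meant to be used directly,
# so a concrete subclass stands in here; 0.01 is only a placeholder value.
opt = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.control_dependencies([opt]):  # run opt first
  updated_weight = tf.identity(weight)  # then run this op

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  # each call now returns the weight after the update step
  sess.run(updated_weight, feed_dict={...})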

Note that the example above uses tf.identity(). Why it is needed is explained in detail in my next blog post; readers who are unsure can head there.

control_flow_ops.with_dependencies

Besides the commonly used tf.control_dependencies(), we also come across control_flow_ops.with_dependencies(). Both functions can enforce dependencies; they just go about it differently.

with_dependencies(dependencies, output_tensor, name=None)
Produces the content of `output_tensor` only after `dependencies`.
(That is, output_tensor is handed back only after all of the dependency ops have completed.)
  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this means that there is
  no guarantee that `output_tensor` will be evaluated after any `dependencies`
  have run.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 
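The docstring above can be pictured as a function-style gate: an identity node that passes `output_tensor` through but also waits for `dependencies`. The following is a toy pure-Python sketch under that reading, not the TensorFlow implementation; `Node`, `evaluate` and the node names are hypothetical. The toy scheduler deliberately visits the data input before the control inputs, to illustrate the docstring's note that `output_tensor` itself is not guaranteed to be evaluated after the dependencies.

```python
class Node:
    """Toy graph node with explicit data inputs and control inputs."""

    def __init__(self, name, fn, inputs=(), control_inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)
        self.control_inputs = list(control_inputs)


def evaluate(node, trace, done=None):
    """Run node once; data inputs are deliberately visited before control inputs."""
    done = {} if done is None else done
    if node not in done:
        args = [evaluate(i, trace, done) for i in node.inputs]
        for dep in node.control_inputs:
            evaluate(dep, trace, done)
        done[node] = node.fn(*args)
        trace.append(node.name)
    return done[node]


def with_dependencies(dependencies, output_node, name="gated_output"):
    """Toy counterpart: an identity node that also waits for `dependencies`."""
    return Node(name, lambda v: v, inputs=[output_node],
                control_inputs=dependencies)


counter = {"steps": 0}  # hypothetical side-effect target


def bump():
    counter["steps"] += 1
    return counter["steps"]


output = Node("output", lambda: 42)
dep_a = Node("dep_a", bump)
dep_b = Node("dep_b", bump)
gated = with_dependencies([dep_a, dep_b], output)

trace = []
result = evaluate(gated, trace)
print(result, trace, counter["steps"])
# 42 ['output', 'dep_a', 'dep_b', 'gated_output'] 2
```

The gated node finishes last, after both dependencies, even though `output` was computed before either of them ran.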

References:

http://blog.csdn.net/m0_37041325/article/details/76943364

https://stackoverflow.com/questions/34877523/in-tensorflow-what-is-tf-identity-used-for

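tf.control_dependencies(control_inputs)

arguments: control_inputs: A list of `Operation` or `Tensor` objects which must be executed or computed before running the operations defined in the context. (Note that control_inputs is a list.)
return: A context manager that specifies control dependencies for all operations constructed within the context.

This method controls the order in which operations (ops) execute. It only takes effect for ops that are created inside the context; a bare reference to an existing tensor inside the context creates no op and therefore picks up no dependency.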

tf.identity(input, name=None) 

Args:
input: A Tensor.
name: A name for the operation (optional).

Returns: A tensor with the same shape and contents as the input tensor or value.

StackOverflow has an example that uses the two together:

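import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = x  # plain Python assignment: no new op is created here
init = tf.global_variables_initializer()

with tf.Session() as session:
    init.run()
    for i in range(5):
        print(y.eval())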


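For this program the output is: 0.0 0.0 0.0 0.0 0.0
Printing the variable x also gives 0.0.

This shows that the x_plus_1 op was never executed. As I understand it, tf.control_dependencies does make the ops in its argument list run before the ops created inside the with block, but y = x does not create an op at all: it is a plain Python assignment that binds the name y to the existing variable x. Since no new op is constructed inside the context, nothing carries the control dependency, and evaluating y never triggers x_plus_1.

So we can change y = x to y = tf.identity(x). That statement does create an op, so the ops listed in tf.control_dependencies run first and y = tf.identity(x) runs afterwards. The output becomes 1.0 2.0 3.0 4.0 5.0, and the final value of x is 5.0. The complete program is: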

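import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = tf.identity(x)  # new op created inside the context, so it waits for x_plus_1
init = tf.global_variables_initializer()

with tf.Session() as session:
    init.run()
    for i in range(5):
        print(y.eval())  # every eval triggers x_plus_1 first
    print(x.eval())  # final value of x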
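The difference between the two programs can be reproduced without TensorFlow. The sketch below is a toy pure-Python model, with a hypothetical `Node` class and a class attribute standing in for the `with` block; it assumes a scheduler that runs control inputs before data inputs, which is a simplification. It shows that `y = x` only rebinds a name, while an identity node created inside the scope carries the dependency.

```python
class Node:
    """Toy graph node; control inputs are fixed when the node is created."""

    active_control_inputs = []  # stand-in for the current control scope

    def __init__(self, fn, inputs=()):
        self.fn = fn
        self.inputs = list(inputs)
        self.control_inputs = list(Node.active_control_inputs)

    def eval(self):
        for dep in self.control_inputs:  # toy rule: control inputs run first
            dep.eval()
        return self.fn(*[i.eval() for i in self.inputs])


store = {"x": 0.0}  # hypothetical variable storage


def add_one():
    store["x"] += 1.0
    return store["x"]


x = Node(lambda: store["x"])  # plays the role of reading the variable
x_plus_1 = Node(add_one)      # plays the role of tf.assign_add(x, 1)

Node.active_control_inputs = [x_plus_1]  # entering the "with" block
y_alias = x                              # only rebinds a name: no new node
y_identity = Node(lambda v: v, [x])      # new node, created inside the scope
Node.active_control_inputs = []          # leaving the "with" block

alias_values = [y_alias.eval() for _ in range(5)]
identity_values = [y_identity.eval() for _ in range(5)]
print(y_alias is x, alias_values)   # True [0.0, 0.0, 0.0, 0.0, 0.0]
print(identity_values, store["x"])  # [1.0, 2.0, 3.0, 4.0, 5.0] 5.0
```

Because `y_alias` is the very same object as `x`, it was created before the scope was active and has no control inputs, which matches the all-zero output of the first program.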
 

Reposted from blog.csdn.net/qq_30638831/article/details/83420296