tf.assign()函数解析

这两天看batch normalization的代码时,碰到一个函数tf.train.ExponentialMovingAverage(),在样例代码中看到了tf.assign()函数,特此记录。

tf.assign(
		ref, 
		value, 
		validate_shape=None, 
		use_locking=None, 
		name=None)

参数:

  • ref是一个可变的张量。应该来自一个变量节点。可能是未初始化的。

  • value一个张量。必须具有与ref相同的类型。要分配给变量的值。

Returns :
A Tensor that will hold the new value of ‘ref’ after the assignment has completed.

一个’张量’,它将在赋值完成后保持ref的新值。

Same as “ref”. Returned as a convenience for operations that want to use the new value after the variable has been reset.

将value赋值给ref,并输出ref,即ref=value。为希望在重置变量后使用新值的操作提供方便。

例子1:

import tensorflow as tf

a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30])
c = a + [10, 20]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("a:",sess.run(a)) # => [10 20]
    print("c:",sess.run(c)) # => [10 20]+[10 20]=[20 40],因为b没有被run所以a还是[10 20]
    print("b:",sess.run(b)) # => ref:a=[20 30],运行b,对a进行assign
    print("a_1:",sess.run(a)) # => [20 30],因为b被run过了,所以a为[20 30]
    print("c_1:",sess.run(c)) # => [20 30]+[10 20]=[30 50],因为b被run过了,所以a为[20,30], 那么c就是[30 50]
> a: [10 20]
> c: [20 40]
> b: [20 30]
> a_1:[20 30]
> c_1: [30 50]

returns中强调需要在assign()被执行了才会返回新值,根据上面的例子1我们就好理解多了,下面再来看一下例子。

例子2:

import tensorflow as tf

a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30])
c = b + [10, 20]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a)) # => [10 20] 
    print(sess.run(c)) # => [30 50],运行c的时候,由于c中含有b,所以b也被运行了
    print(sess.run(a)) # => [20 30]
> [10 20]
> [30 50]
> [20 30]

所以我们发现如果assign未被执行,那么ref值就不更新。

猜你喜欢

转载自blog.csdn.net/TeFuirnever/article/details/88906895