TensorFlow: error-prone functions

1. tf.assign(): function description

def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update 'ref' by assigning 'value' to it.

  This operation outputs a Tensor that holds the new value of 'ref' after
    the value has been assigned. This makes it easier to chain operations
    that need to use the reset value.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    value: A `Tensor`. Must have the same type as `ref`.
      The value to be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`.
      If true, the operation will validate that the shape
      of 'value' matches the shape of the Tensor being assigned to.  If false,
      'ref' will take on the shape of 'value'.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of 'ref' after
      the assignment has completed.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  return ref.assign(value)
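The validate_shape rule in the docstring can be sketched in plain Python. This is a toy model, not TensorFlow. `shape_of` and `ShapeCheckedVar` are hypothetical names. Shapes are computed from nested lists, and the error text only imitates the idea of TF's message. With validate_shape=True, a value of a different shape is rejected. With validate_shape=False, the variable takes on the new shape. In both cases the new value is returned so it can be chained.

```python
def shape_of(value):
    """Shape of a nested Python list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


class ShapeCheckedVar:
    """Hypothetical stand-in for a variable; NOT the TensorFlow API."""

    def __init__(self, initial_value):
        self.value = initial_value
        self.shape = shape_of(initial_value)

    def assign(self, new_value, validate_shape=True):
        new_shape = shape_of(new_value)
        if validate_shape and new_shape != self.shape:
            # validate_shape=True: reject a value whose shape differs.
            raise ValueError(
                f"Shapes {list(self.shape)} and {list(new_shape)} are incompatible")
        # validate_shape=False (or matching shapes): take on the new value and shape.
        self.value = new_value
        self.shape = new_shape
        # Like tf.assign, hand back the new value so later ops can use it.
        return self.value


v = ShapeCheckedVar([10, 20])
try:
    v.assign([20, 30, 1])  # validate_shape defaults to True -> error
except ValueError as err:
    print("error:", err)  # error: Shapes [2] and [3] are incompatible
print(v.assign([20, 30, 1], validate_shape=False))  # [20, 30, 1]
print(v.shape)  # (3,)
```

In real TensorFlow 1.x the check runs when the Assign op is added to the graph. This toy performs it at call time because it has no graph.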

Notes:

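1. The variable holds the new value only after the tf.assign() op has actually been executed, for example by sess.run(). Merely creating the op does not change the variable.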

2. The parameter validate_shape defaults to True. In that case the shape of the new value must match the shape of the old value, otherwise an error is raised. For example:

import tensorflow as tf
a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30,1])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("run a : ",sess.run(a))
    print("run b : ",sess.run(b))
    print("run a again : ",sess.run(a))
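# validate_shape defaults to True, and [20, 30, 1] (shape [3]) does not match [10, 20] (shape [2]),
# so the tf.assign() line itself fails while the graph is being built, before any session runs:
# ValueError: Dimension 0 in both shapes must be equal, but are 2 and 3. Shapes are [2] and [3]. for 'Assign' (op: 'Assign') with input shapes: [2], [3].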



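If validate_shape=False is passed instead, no error is raised and the variable takes on the shape of the new value:

import tensorflow as tf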
a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30,1],validate_shape=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("run a : ",sess.run(a))
    print(a)
    print("run b : ",sess.run(b))
    print("run a again : ",sess.run(a))
Output:
run a :  [10 20]
<tf.Variable 'Variable:0' shape=(2,) dtype=int32_ref>
run b :  [20 30  1]
run a again :  [20 30  1]



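The following example illustrates note 1. The variable changes only when the assign op itself, or an op that depends on it, is run:

import tensorflow as tf
a = tf.Variable([10, 20])
b = tf.assign(a, [20, 30])
c = b + [10, 20]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))  # => [10 20]
    print(sess.run(c))  # => [30 50]  c depends on b, so running c also runs the assign op b
    print(sess.run(a))  # => [20 30]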

# Until the tf.assign() op is executed, ref is not updated.
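Running c also updates a because of deferred (graph) execution, which can be shown with a pure-Python toy model. It is not TensorFlow. `LazyNode`, `LazyVar`, `lazy_assign`, `lazy_add` and `LazySession` are hypothetical names. The model also simplifies things TF does, such as evaluating each op at most once per run call.

```python
class LazyNode:
    """A deferred computation: nothing happens until a LazySession runs it."""

    def __init__(self, fn, *deps):
        self.fn = fn        # called with the values of deps, in order
        self.deps = deps    # nodes that must be evaluated first


class LazyVar(LazyNode):
    """Hypothetical variable: reading it returns a copy of its current storage."""

    def __init__(self, initial_value):
        self.storage = list(initial_value)
        super().__init__(lambda: list(self.storage))


def lazy_assign(var, value):
    """Build (but do not run) an op that overwrites var's storage."""
    def do_assign():
        var.storage = list(value)   # the side effect happens only when run
        return list(var.storage)    # like tf.assign, output the new value
    return LazyNode(do_assign)


def lazy_add(node, const):
    """Build an op computing elementwise node + const."""
    return LazyNode(lambda xs: [x + y for x, y in zip(xs, const)], node)


class LazySession:
    def run(self, node):
        # Evaluate dependencies first (no caching in this toy), then the node.
        args = [self.run(dep) for dep in node.deps]
        return node.fn(*args)


a = LazyVar([10, 20])
b = lazy_assign(a, [20, 30])
c = lazy_add(b, [10, 20])
sess = LazySession()
print(sess.run(a))  # [10, 20]  b has only been built, not run
print(sess.run(c))  # [30, 50]  running c runs its dependency b
print(sess.run(a))  # [20, 30]
```

The recursion in `run` is what makes fetching `c` execute the assignment: `b` is a dependency of `c`, while fetching `a` alone never touches `b`.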


  


Reposted from blog.csdn.net/biubiubiu888/article/details/82055476