TensorFlow Source Analysis -- accumulate_n

Tags (space-separated): Tensorflow


Example:

import tensorflow as tf
x = tf.constant([[1,2],[3,4]])
y = tf.constant([[1,2],[3,4]])
result = tf.accumulate_n([x,y])
sess = tf.Session()
print(sess.run(result))  ## print the computed result
print(type(sess.run(result))) ## print the type of the result
print(type(y)) ## print the type of the constant
>>> [[2 4]
 [6 8]]
>>> <class 'numpy.ndarray'>
>>> <class 'tensorflow.python.framework.ops.Tensor'>

Note: this op can save memory. When there are many inputs that become ready at different times, it sums them incrementally as each arrives.
"""Returns the element-wise sum of a list of tensors.

Optionally, pass shape and tensor_dtype for shape and type checking,
otherwise, these are inferred.

tf.accumulate_n performs the same operation as tf.add_n, but does not wait for all of its inputs to be ready before beginning to sum. This can save memory if inputs are ready at different times, since minimum temporary storage is proportional to the output size rather than the inputs size.

accumulate_n is differentiable (but wasn't previous to TensorFlow 1.7).

Args:
inputs: A list of Tensor objects, each with same shape and type.
shape: Shape of elements of inputs.
tensor_dtype: The type of inputs.
name: A name for the operation (optional).

Returns:
A Tensor of same shape and type as the elements of inputs.

Raises:
ValueError: If inputs don't all have same shape and dtype or the shape
cannot be inferred.
"""
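The memory-saving idea described above can be sketched in plain NumPy (this is only an illustration of the strategy, not the real TensorFlow kernel): instead of keeping every input alive until one final add, a single accumulator buffer of the output's size folds inputs in one at a time.

```python
import numpy as np

def accumulate_n_sketch(inputs):
    # Illustrative sketch of accumulate_n's strategy: keep one temporary
    # buffer the size of the output and fold each input into it as it
    # arrives, instead of holding all inputs alive for a final add_n.
    acc = np.zeros_like(inputs[0])
    for x in inputs:
        acc += x  # each input could be released right after this step
    return acc

x = np.array([[1, 2], [3, 4]])
y = np.array([[1, 2], [3, 4]])
print(accumulate_n_sketch([x, y]))  # element-wise sum: [[2, 4], [6, 8]]
```

The result matches add_n; only the peak temporary storage differs, which is why the docstring says it is proportional to the output size rather than the total input size.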

Source:

def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):

  def _input_error():
    return ValueError("inputs must be a list of at least one Tensor with the "
                      "same dtype and shape")
    ## Raise an error if there is no input, the input is not a list or tuple, not every element is a Tensor, or the inputs' dtypes differ
  if not inputs or not isinstance(inputs, (list, tuple)):
    raise _input_error()
  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
  if not all(isinstance(x, ops.Tensor) for x in inputs):
    raise _input_error()
  if not all(x.dtype == inputs[0].dtype for x in inputs):
    raise _input_error()
    ## If shape is not None, use it as the starting shape; otherwise start from an unknown shape, then merge in each input's shape
  if shape is not None:
    shape = tensor_shape.as_shape(shape)
  else:
    shape = tensor_shape.unknown_shape()
  for input_tensor in inputs:
    if isinstance(input_tensor, ops.Tensor):
      shape = shape.merge_with(input_tensor.get_shape())

  # tensor_dtype is for safety only; operator's output type computed in C++
  if tensor_dtype is not None and tensor_dtype != inputs[0].dtype:
    raise TypeError("tensor_dtype is {}, but input is of type {}".format(
        tensor_dtype, inputs[0].dtype))

  if len(inputs) == 1 and name is None:
    return inputs[0]
  elif len(inputs) == 1 and name is not None:
    return array_ops.identity(inputs[0], name=name)
  elif context.executing_eagerly():
    # TemporaryVariable not currently supported in eager mode; fall back
    # onto AddN for now.
    # TODO(frreiss) remove this once the lifetime of eager variables gets
    # addressed
    return add_n(inputs, name=name)
  else:
    return gen_math_ops.accumulate_nv2(inputs, name=name, shape=shape)  # delegates to the AccumulateNV2 op, which does the actual work
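The validation at the top of the function can be seen in isolation. The pure-Python mirror below (the name `check_inputs` is illustrative, not part of TensorFlow) reproduces the two checks that trigger the shared ValueError: the input must be a non-empty list or tuple, and all elements must share one dtype.

```python
import numpy as np

def check_inputs(inputs):
    # Illustrative mirror of accumulate_n's input validation: a non-empty
    # list/tuple whose elements all share a single dtype, else ValueError.
    if not inputs or not isinstance(inputs, (list, tuple)):
        raise ValueError("inputs must be a list of at least one Tensor "
                         "with the same dtype and shape")
    if not all(a.dtype == inputs[0].dtype for a in inputs):
        raise ValueError("inputs must be a list of at least one Tensor "
                         "with the same dtype and shape")

x = np.array([1, 2], dtype=np.int32)
y = np.array([1.0, 2.0], dtype=np.float32)
check_inputs([x, x])      # same dtype: passes silently
try:
    check_inputs([x, y])  # int32 mixed with float32: rejected
except ValueError as e:
    print("rejected:", e)
```

Note that both failure modes raise the same message, matching `_input_error()` in the source above.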

Reposted from blog.csdn.net/jiangzhenkang/article/details/80723891