TensorFlow dynamic_rnn source code analysis

Environment: Python 3.6, TensorFlow 1.11

The test below runs TensorFlow in eager mode, which makes it easier to debug and to inspect intermediate results.

Test code:

 1 import tensorflow as tf
 2 
 3 tf.enable_eager_execution()
 4 
 5 batch_size = 4 
 6 input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
 7 cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
 8 init_state = cell.zero_state(batch_size, dtype=tf.float32)
 9 seq_length = tf.constant([2,3,2,3],dtype=tf.int32)
10 import pdb; pdb.set_trace()
11 output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state, sequence_length=seq_length, time_major=True)  # If time_major is True, the first dimension of the input holds the time steps; this is recommended because it avoids transposes and runs a bit faster.
12 # If time_major is False, the second dimension of the input holds the time steps.
13 # If time_major is True, output has shape [steps, batch_size, depth]; otherwise [batch_size, max_time, depth], i.e. the same layout as the input.
14 # final_state is the final state of the LSTM; it contains c and h, each of shape [batch_size, n_hidden].
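For reference, a quick sanity check of the shapes and of the zero-padding past each sequence's length under the settings above (remove the pdb.set_trace() line first); the random values will differ, but the shapes and the zero rows should match:

```python
# Run after the snippet above (with the pdb.set_trace() line removed).
print(output.shape)          # (3, 4, 10) -> [max_time, batch_size, num_units]
print(final_state.c.shape)   # (4, 10)    -> [batch_size, num_units]
print(final_state.h.shape)   # (4, 10)
# Batch entries 0 and 2 have sequence length 2, so their outputs at the last
# time step (index 2) are all zeros, while entries 1 and 3 are not:
print(output[2, 0].numpy())  # all zeros
print(output[2, 1].numpy())  # non-zero values
```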

tf.nn.dynamic_rnn is defined in tensorflow/python/ops/rnn.py; we step into it to debug:

  1 @tf_export("nn.dynamic_rnn")
  2 def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
  3                 dtype=None, parallel_iterations=None, swap_memory=False,
  4                 time_major=False, scope=None):
  5   """Creates a recurrent neural network specified by RNNCell `cell`.
  6 
  7   Performs fully dynamic unrolling of `inputs`.
  8 
  9   Example:
 10 
 11   ```python
 12   # create a BasicRNNCell
 13   rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
 14 
 15   # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
 16 
 17   # defining initial state
 18   initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
 19 
 20   # 'state' is a tensor of shape [batch_size, cell_state_size]
 21   outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
 22                                      initial_state=initial_state,
 23                                      dtype=tf.float32)
 24   ```
 25 
 26   ```python
 27   # create 2 LSTMCells
 28   rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
 29 
 30   # create a RNN cell composed sequentially of a number of RNNCells
 31   multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
 32 
 33   # 'outputs' is a tensor of shape [batch_size, max_time, 256]
 34   # 'state' is a N-tuple where N is the number of LSTMCells containing a
 35   # tf.contrib.rnn.LSTMStateTuple for each cell
 36   outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
 37                                      inputs=data,
 38                                      dtype=tf.float32)
 39   ```
 40 
 41 
 42   Args:
 43     cell: An instance of RNNCell.
 44     inputs: The RNN inputs.
 45       If `time_major == False` (default), this must be a `Tensor` of shape:
 46         `[batch_size, max_time, ...]`, or a nested tuple of such
 47         elements.
 48       If `time_major == True`, this must be a `Tensor` of shape:
 49         `[max_time, batch_size, ...]`, or a nested tuple of such
 50         elements.
 51       This may also be a (possibly nested) tuple of Tensors satisfying
 52       this property.  The first two dimensions must match across all the inputs,
 53       but otherwise the ranks and other shape components may differ.
 54       In this case, input to `cell` at each time-step will replicate the
 55       structure of these tuples, except for the time dimension (from which the
 56       time is taken).
 57       The input to `cell` at each time step will be a `Tensor` or (possibly
 58       nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
 59     sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
 60       Used to copy-through state and zero-out outputs when past a batch
 61       element's sequence length.  So it's more for performance than correctness.
 62     initial_state: (optional) An initial state for the RNN.
 63       If `cell.state_size` is an integer, this must be
 64       a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
 65       If `cell.state_size` is a tuple, this should be a tuple of
 66       tensors having shapes `[batch_size, s] for s in cell.state_size`.
 67     dtype: (optional) The data type for the initial state and expected output.
 68       Required if initial_state is not provided or RNN state has a heterogeneous
 69       dtype.
 70     parallel_iterations: (Default: 32).  The number of iterations to run in
 71       parallel.  Those operations which do not have any temporal dependency
 72       and can be run in parallel, will be.  This parameter trades off
 73       time for space.  Values >> 1 use more memory but take less time,
 74       while smaller values use less memory but computations take longer.
 75     swap_memory: Transparently swap the tensors produced in forward inference
 76       but needed for back prop from GPU to CPU.  This allows training RNNs
 77       which would typically not fit on a single GPU, with very minimal (or no)
 78       performance penalty.
 79     time_major: The shape format of the `inputs` and `outputs` Tensors.
 80       If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
 81       If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
 82       Using `time_major = True` is a bit more efficient because it avoids
 83       transposes at the beginning and end of the RNN calculation.  However,
 84       most TensorFlow data is batch-major, so by default this function
 85       accepts input and emits output in batch-major form.
 86     scope: VariableScope for the created subgraph; defaults to "rnn".
 87 
 88   Returns:
 89     A pair (outputs, state) where:
 90 
 91     outputs: The RNN output `Tensor`.
 92 
 93       If time_major == False (default), this will be a `Tensor` shaped:
 94         `[batch_size, max_time, cell.output_size]`.
 95 
 96       If time_major == True, this will be a `Tensor` shaped:
 97         `[max_time, batch_size, cell.output_size]`.
 98 
 99       Note, if `cell.output_size` is a (possibly nested) tuple of integers
100       or `TensorShape` objects, then `outputs` will be a tuple having the
101       same structure as `cell.output_size`, containing Tensors having shapes
102       corresponding to the shape data in `cell.output_size`.
103 
104     state: The final state.  If `cell.state_size` is an int, this
105       will be shaped `[batch_size, cell.state_size]`.  If it is a
106       `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
107       If it is a (possibly nested) tuple of ints or `TensorShape`, this will
108       be a tuple having the corresponding shapes. If cells are `LSTMCells`
109       `state` will be a tuple containing a `LSTMStateTuple` for each cell.
110 
111   Raises:
112     TypeError: If `cell` is not an instance of RNNCell.
113     ValueError: If inputs is None or an empty list.
114   """
115   rnn_cell_impl.assert_like_rnncell("cell", cell)
116 
117   with vs.variable_scope(scope or "rnn") as varscope:
118     # Create a new scope in which the caching device is either
119     # determined by the parent scope, or is set to place the cached
120     # Variable using the same placement as for the rest of the RNN.
121     if _should_cache():
122       if varscope.caching_device is None:
123         varscope.set_caching_device(lambda op: op.device)
124 
125     # By default, time_major==False and inputs are batch-major: shaped
126     #   [batch, time, depth]
127     # For internal calculations, we transpose to [time, batch, depth]
128     flat_input = nest.flatten(inputs)
129 
130     if not time_major:
131       # (B,T,D) => (T,B,D)
132       flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
133       flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
134 
135     parallel_iterations = parallel_iterations or 32
136     if sequence_length is not None:
137       sequence_length = math_ops.to_int32(sequence_length)
138       if sequence_length.get_shape().ndims not in (None, 1):
139         raise ValueError(
140             "sequence_length must be a vector of length batch_size, "
141             "but saw shape: %s" % sequence_length.get_shape())
142       sequence_length = array_ops.identity(  # Just to find it in the graph.
143           sequence_length, name="sequence_length")
144 
145     batch_size = _best_effort_input_batch_size(flat_input)
146 
147     if initial_state is not None:
148       state = initial_state
149     else:
150       if not dtype:
151         raise ValueError("If there is no initial_state, you must give a dtype.")
152       if getattr(cell, "get_initial_state", None) is not None:
153         state = cell.get_initial_state(
154             inputs=None, batch_size=batch_size, dtype=dtype)
155       else:
156         state = cell.zero_state(batch_size, dtype)
157 
158     def _assert_has_shape(x, shape):
159       x_shape = array_ops.shape(x)
160       packed_shape = array_ops.stack(shape)
161       return control_flow_ops.Assert(
162           math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
163           ["Expected shape for Tensor %s is " % x.name,
164            packed_shape, " but saw shape: ", x_shape])
165 
166     if not context.executing_eagerly() and sequence_length is not None:
167       # Perform some shape validation
168       with ops.control_dependencies(
169           [_assert_has_shape(sequence_length, [batch_size])]):
170         sequence_length = array_ops.identity(
171             sequence_length, name="CheckSeqLen")
172 
173     inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
174 
175     (outputs, final_state) = _dynamic_rnn_loop(
176         cell,
177         inputs,
178         state,
179         parallel_iterations=parallel_iterations,
180         swap_memory=swap_memory,
181         sequence_length=sequence_length,
182         dtype=dtype)
183 
184     # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
185     # If we are performing batch-major calculations, transpose output back
186     # to shape [batch, time, depth]
187     if not time_major:
188       # (T,B,D) => (B,T,D)
189       outputs = nest.map_structure(_transpose_batch_time, outputs)
190 
191     return (outputs, final_state)
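As an aside, the batch-major to time-major conversion at lines 130-133 just swaps the first two axes; for a rank-3 input it is equivalent to the following minimal sketch (the real _transpose_batch_time also handles higher ranks and dynamic shapes):

```python
import tensorflow as tf
tf.enable_eager_execution()

x = tf.random_normal([4, 3, 6])          # batch-major: [batch, time, depth]
x_tm = tf.transpose(x, perm=[1, 0, 2])   # time-major:  [time, batch, depth]
print(x_tm.shape)                        # (3, 4, 6)
```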

It ultimately calls _dynamic_rnn_loop:

  1 def _dynamic_rnn_loop(cell,
  2                       inputs,
  3                       initial_state,
  4                       parallel_iterations,
  5                       swap_memory,
  6                       sequence_length=None,
  7                       dtype=None):
  8   """Internal implementation of Dynamic RNN.
  9 
 10   Args:
 11     cell: An instance of RNNCell.
 12     inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
 13       tuple of such elements.
 14     initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
 15       `cell.state_size` is a tuple, then this should be a tuple of
 16       tensors having shapes `[batch_size, s] for s in cell.state_size`.
 17     parallel_iterations: Positive Python int.
 18     swap_memory: A Python boolean
 19     sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
 20     dtype: (optional) Expected dtype of output. If not specified, inferred from
 21       initial_state.
 22 
 23   Returns:
 24     Tuple `(final_outputs, final_state)`.
 25     final_outputs:
 26       A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
 27       `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
 28       objects, then this returns a (possibly nested) tuple of Tensors matching
 29       the corresponding shapes.
 30     final_state:
 31       A `Tensor`, or possibly nested tuple of Tensors, matching in length
 32       and shapes to `initial_state`.
 33   Raises:
 34     ValueError: If the input depth cannot be inferred via shape inference
 35       from the inputs.
 36   """
 37   import pdb;pdb.set_trace()
 38   state = initial_state
 39   assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
 40 
 41   state_size = cell.state_size#LSTMStateTuple(c=10, h=10)
 42 
 43   flat_input = nest.flatten(inputs)#list,~[0].shape=TensorShape([Dimension(3), Dimension(4), Dimension(6)])
 44   flat_output_size = nest.flatten(cell.output_size)#[10]
 45 
 46   # Construct an initial output
 47   input_shape = array_ops.shape(flat_input[0])#array([3, 4, 6])
 48   time_steps = input_shape[0]#3
 49   batch_size = _best_effort_input_batch_size(flat_input)#4
 50 
 51   inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
 52                            for input_ in flat_input)#(TensorShape([Dimension(3), Dimension(4), Dimension(6)]),)
 53 
 54   const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]#3,4
 55 
 56   for shape in inputs_got_shape:
 57     if not shape[2:].is_fully_defined():
 58       raise ValueError(
 59           "Input size (depth of inputs) must be accessible via shape inference,"
 60           " but saw value None.")
 61     got_time_steps = shape[0].value#3
 62     got_batch_size = shape[1].value#4
 63     if const_time_steps != got_time_steps:
 64       raise ValueError(
 65           "Time steps is not the same for all the elements in the input in a "
 66           "batch.")
 67     if const_batch_size != got_batch_size:
 68       raise ValueError(
 69           "Batch_size is not the same for all the elements in the input.")
 70 
 71   # Prepare dynamic conditional copying of state & output
 72   def _create_zero_arrays(size):
 73     size = _concat(batch_size, size)
 74     return array_ops.zeros(
 75         array_ops.stack(size), _infer_state_dtype(dtype, state))
 76 
 77   flat_zero_output = tuple(_create_zero_arrays(output)
 78                            for output in flat_output_size)#tuple,~[0].shape:TensorShape([Dimension(4), Dimension(10)])
 79   zero_output = nest.pack_sequence_as(structure=cell.output_size,
 80                                       flat_sequence=flat_zero_output)#TensorShape([Dimension(4), Dimension(10)])
 81 
 82   if sequence_length is not None:
 83     min_sequence_length = math_ops.reduce_min(sequence_length)#2
 84     max_sequence_length = math_ops.reduce_max(sequence_length)#3
 85   else:
 86     max_sequence_length = time_steps
 87 
 88   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
 89 
 90   with ops.name_scope("dynamic_rnn") as scope:
 91     base_name = scope
 92 
 93   def _create_ta(name, element_shape, dtype):
 94     return tensor_array_ops.TensorArray(dtype=dtype,
 95                                         size=time_steps,
 96                                         element_shape=element_shape,
 97                                         tensor_array_name=base_name + name)
 98 
 99   in_graph_mode = not context.executing_eagerly()
100   if in_graph_mode:
101     output_ta = tuple(
102         _create_ta(
103             "output_%d" % i,
104             element_shape=(tensor_shape.TensorShape([const_batch_size])
105                            .concatenate(
106                                _maybe_tensor_shape_from_tensor(out_size))),
107             dtype=_infer_state_dtype(dtype, state))
108         for i, out_size in enumerate(flat_output_size))
109     input_ta = tuple(
110         _create_ta(
111             "input_%d" % i,
112             element_shape=flat_input_i.shape[1:],
113             dtype=flat_input_i.dtype)
114         for i, flat_input_i in enumerate(flat_input))
115     input_ta = tuple(ta.unstack(input_)
116                      for ta, input_ in zip(input_ta, flat_input))
117   else:
118     output_ta = tuple([0 for _ in range(time_steps.numpy())]
119                       for i in range(len(flat_output_size)))#([0, 0, 0],)
120     input_ta = flat_input#list,~[0].shape=TensorShape([Dimension(3), Dimension(4), Dimension(6)])
121 
122   def _time_step(time, output_ta_t, state):
123     """Take a time step of the dynamic RNN.
124 
125     Args:
126       time: int32 scalar Tensor.
127       output_ta_t: List of `TensorArray`s that represent the output.
128       state: nested tuple of vector tensors that represent the state.
129 
130     Returns:
131       The tuple (time + 1, output_ta_t with updated flow, new_state).
132     """
133     import pdb;pdb.set_trace()
134     if in_graph_mode:
135       input_t = tuple(ta.read(time) for ta in input_ta)
136       # Restore some shape information
137       for input_, shape in zip(input_t, inputs_got_shape):
138         input_.set_shape(shape[1:])
139     else:
140       input_t = tuple(ta[time.numpy()] for ta in input_ta)#TensorShape([Dimension(4), Dimension(6)])
141 
142     input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)#TensorShape([Dimension(4), Dimension(6)])
143     # Keras RNN cells only accept state as list, even if it's a single tensor.
144     is_keras_rnn_cell = _is_keras_rnn_cell(cell)
145     if is_keras_rnn_cell and not nest.is_sequence(state):
146       state = [state]
147     call_cell = lambda: cell(input_t, state)
148 
149     if sequence_length is not None:
150       (output, new_state) = _rnn_step(
151           time=time,
152           sequence_length=sequence_length,
153           min_sequence_length=min_sequence_length,
154           max_sequence_length=max_sequence_length,
155           zero_output=zero_output,
156           state=state,
157           call_cell=call_cell,
158           state_size=state_size,
159           skip_conditionals=True)
160     else:
161       (output, new_state) = call_cell()
162 
163     # Keras cells always wrap state as list, even if it's a single tensor.
164     if is_keras_rnn_cell and len(new_state) == 1:
165       new_state = new_state[0]
166     # Pack state if using state tuples
167     output = nest.flatten(output)
168 
169     if in_graph_mode:
170       output_ta_t = tuple(
171           ta.write(time, out) for ta, out in zip(output_ta_t, output))
172     else:
173       for ta, out in zip(output_ta_t, output):
174         ta[time.numpy()] = out
175 
176     return (time + 1, output_ta_t, new_state)
177 
178   if in_graph_mode:
179     # Make sure that we run at least 1 step, if necessary, to ensure
180     # the TensorArrays pick up the dynamic shape.
181     loop_bound = math_ops.minimum(
182         time_steps, math_ops.maximum(1, max_sequence_length))
183   else:
184     # Using max_sequence_length isn't currently supported in the Eager branch.
185     loop_bound = time_steps#3
186 
187   _, output_final_ta, final_state = control_flow_ops.while_loop(
188       cond=lambda time, *_: time < loop_bound,
189       body=_time_step,
190       loop_vars=(time, output_ta, state),
191       parallel_iterations=parallel_iterations,
192       maximum_iterations=time_steps,
193       swap_memory=swap_memory)
194 
195   # Unpack final output if not using output tuples.
196   if in_graph_mode:
197     final_outputs = tuple(ta.stack() for ta in output_final_ta)
198     # Restore some shape information
199     for output, output_size in zip(final_outputs, flat_output_size):
200       shape = _concat(
201           [const_time_steps, const_batch_size], output_size, static=True)
202       output.set_shape(shape)
203   else:
204     final_outputs = output_final_ta
205 
206   final_outputs = nest.pack_sequence_as(
207       structure=cell.output_size, flat_sequence=final_outputs)
208   if not in_graph_mode:
209     final_outputs = nest.map_structure_up_to(
210         cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
211 
212   return (final_outputs, final_state)

As the code shows, dynamic_rnn relies mainly on while_loop to handle sequences of different lengths within a batch.
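To make the loop structure concrete, here is a toy while_loop with the same time/state progression as _time_step; the "cell" is just a running sum rather than a real RNNCell, so this is only a sketch of the control flow, not the actual implementation:

```python
import tensorflow as tf
tf.enable_eager_execution()

inputs = tf.random_normal([3, 4, 6])     # time-major: [time, batch, depth]
time_steps = tf.shape(inputs)[0]

def body(time, state):
    x_t = inputs[time]                   # slice one time step: [batch, depth]
    new_state = state + tf.reduce_sum(x_t, axis=1, keepdims=True)  # toy "cell"
    return time + 1, new_state

_, final_state = tf.while_loop(
    cond=lambda time, state: time < time_steps,
    body=body,
    loop_vars=(tf.constant(0), tf.zeros([4, 1])))
print(final_state.shape)                 # (4, 1)
```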

Lines 82-86 above show that if the sequence_length argument is not given, max_sequence_length is simply time_steps = input.shape[0]. When sequence_length is given, _rnn_step is called and the outputs past each sequence's length are set to zero; this is implemented at lines 60 and 70 of the code below (a standalone sketch follows after the listing).

  1 def _rnn_step(
  2     time, sequence_length, min_sequence_length, max_sequence_length,
  3     zero_output, state, call_cell, state_size, skip_conditionals=False):
  4   """Calculate one step of a dynamic RNN minibatch.
  5 
  6   Returns an (output, state) pair conditioned on `sequence_length`.
  7   When skip_conditionals=False, the pseudocode is something like:
  8 
  9   if t >= max_sequence_length:
 10     return (zero_output, state)
 11   if t < min_sequence_length:
 12     return call_cell()
 13 
 14   # Selectively output zeros or output, old state or new state depending
 15   # on whether we've finished calculating each row.
 16   new_output, new_state = call_cell()
 17   final_output = np.vstack([
 18     zero_output if time >= sequence_length[r] else new_output_r
 19     for r, new_output_r in enumerate(new_output)
 20   ])
 21   final_state = np.vstack([
 22     state[r] if time >= sequence_length[r] else new_state_r
 23     for r, new_state_r in enumerate(new_state)
 24   ])
 25   return (final_output, final_state)
 26 
 27   Args:
 28     time: int32 `Tensor` scalar.
 29     sequence_length: int32 `Tensor` vector of size [batch_size].
 30     min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
 31     max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
 32     zero_output: `Tensor` vector of shape [output_size].
 33     state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
 34       or a list/tuple of such tensors.
 35     call_cell: lambda returning tuple of (new_output, new_state) where
 36       new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
 37       new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
 38     state_size: The `cell.state_size` associated with the state.
 39     skip_conditionals: Python bool, whether to skip using the conditional
 40       calculations.  This is useful for `dynamic_rnn`, where the input tensor
 41       matches `max_sequence_length`, and using conditionals just slows
 42       everything down.
 43 
 44   Returns:
 45     A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
 46       final_output is a `Tensor` matrix of shape [batch_size, output_size]
 47       final_state is either a single `Tensor` matrix, or a tuple of such
 48         matrices (matching length and shapes of input `state`).
 49 
 50   Raises:
 51     ValueError: If the cell returns a state tuple whose length does not match
 52       that returned by `state_size`.
 53   """
 54   import pdb;pdb.set_trace()
 55   # Convert state to a list for ease of use
 56   flat_state = nest.flatten(state)#[c,h],shape=[4,10]
 57   flat_zero_output = nest.flatten(zero_output)#list,~[0].shape:TensorShape([Dimension(4), Dimension(10)])
 58 
 59   # Vector describing which batch entries are finished.
 60   copy_cond = time >= sequence_length#step1:array([False, False, False, False])
 61 
 62   def _copy_one_through(output, new_output):
 63     # TensorArray and scalar get passed through.
 64     if isinstance(output, tensor_array_ops.TensorArray):
 65       return new_output
 66     if output.shape.ndims == 0:
 67       return new_output
 68     # Otherwise propagate the old or the new value.
 69     with ops.colocate_with(new_output):
 70       return array_ops.where(copy_cond, output, new_output)# finished entries keep the zero output / previous state
 71 
 72   def _copy_some_through(flat_new_output, flat_new_state):
 73     # Use broadcasting select to determine which values should get
 74     # the previous state & zero output, and which values should get
 75     # a calculated state & output.
 76     flat_new_output = [
 77         _copy_one_through(zero_output, new_output)
 78         for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
 79     flat_new_state = [
 80         _copy_one_through(state, new_state)
 81         for state, new_state in zip(flat_state, flat_new_state)]
 82     return flat_new_output + flat_new_state
 83 
 84   def _maybe_copy_some_through():
 85     """Run RNN step.  Pass through either no or some past state."""
 86     new_output, new_state = call_cell()
 87 
 88     nest.assert_same_structure(state, new_state)
 89 
 90     flat_new_state = nest.flatten(new_state)
 91     flat_new_output = nest.flatten(new_output)
 92     return control_flow_ops.cond(
 93         # if t < min_seq_len: calculate and return everything
 94         time < min_sequence_length, lambda: flat_new_output + flat_new_state,
 95         # else copy some of it through
 96         lambda: _copy_some_through(flat_new_output, flat_new_state))
 97 
 98   # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
 99   # but benefits from removing cond() and its gradient.  We should
100   # profile with and without this switch here.
101   if skip_conditionals:
102     # Instead of using conditionals, perform the selective copy at all time
103     # steps.  This is faster when max_seq_len is equal to the number of unrolls
104     # (which is typical for dynamic_rnn).
105     new_output, new_state = call_cell()
106     nest.assert_same_structure(state, new_state)
107     new_state = nest.flatten(new_state)#[c,h],shape=(4, 10)
108     new_output = nest.flatten(new_output)#shape=(4, 10)
109     final_output_and_state = _copy_some_through(new_output, new_state)
110   else:
111     empty_update = lambda: flat_zero_output + flat_state
112     final_output_and_state = control_flow_ops.cond(
113         # if t >= max_seq_len: copy all state through, output zeros
114         time >= max_sequence_length, empty_update,
115         # otherwise calculation is required: copy some or all of it through
116         _maybe_copy_some_through)
117 
118   if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
119     raise ValueError("Internal error: state and output were not concatenated "
120                      "correctly.")
121   final_output = final_output_and_state[:len(flat_zero_output)]
122   final_state = final_output_and_state[len(flat_zero_output):]
123 
124   for output, flat_output in zip(final_output, flat_zero_output):
125     output.set_shape(flat_output.get_shape())
126   for substate, flat_substate in zip(final_state, flat_state):
127     if not isinstance(substate, tensor_array_ops.TensorArray):
128       substate.set_shape(flat_substate.get_shape())
129 
130   final_output = nest.pack_sequence_as(
131       structure=zero_output, flat_sequence=final_output)
132   final_state = nest.pack_sequence_as(
133       structure=state, flat_sequence=final_state)
134 
135   return final_output, final_state
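A minimal standalone sketch of the copy-through at lines 60 and 70 above, using the same sequence_length as the test code: for batch entries that are already finished (time >= sequence_length), tf.where keeps the zero output instead of the newly computed value.

```python
import tensorflow as tf
tf.enable_eager_execution()

sequence_length = tf.constant([2, 3, 2, 3])
time = tf.constant(2)
copy_cond = time >= sequence_length            # [True, False, True, False]

new_output = tf.ones([4, 10])                  # stand-in for the cell's new output
zero_output = tf.zeros([4, 10])
final_output = tf.where(copy_cond, zero_output, new_output)
print(final_output[:, 0].numpy())              # [0. 1. 0. 1.]
```

The state is handled the same way, except that the previous state is copied through instead of zeros, which is why final_state already holds each entry's state at its last valid step.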

Reposted from www.cnblogs.com/buyizhiyou/p/9883182.html