Tensorflow RNN源代码解析笔记2:RNN的基本实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MebiuW/article/details/62424586

1 前言

话说上回说到了RNNCell的基本实现,本来按照道理来说,应该介绍LSTM GRU的,但是奈何这些于我而言也是不太熟悉(然后我又悲伤的想到了那个电话,哎),所以不如先说说RNN网络的实现吧,毕竟之前熟悉的是基本的RNNCell,现在再熟悉下RNN的具体实现,那不正好可以完整的学习一个基本版的RNN么?
Tensorflow RNN源代码解析笔记1:RNNCell的基本实现

2 基本说明

Tensorflow提供了一个最基本 的rnn网络实现,具体位置是tf.nn.rnn当中。对了,这里说到的RNN网络都需要设定具体的Cell类型(也就是之前那篇文章里的RNNCell),关于RNNCell具体如何工作这里就不说了,这里主要从网络结构上来看。
那么我们首先看下定义这个基本rnn网络的一些参数:


def rnn(cell, inputs, initial_state=None, dtype=None,
        sequence_length=None, scope=None):

这个RNN网络呢,是一个最基本的RNN网络,简单到什么地步呢,基本就是如下几行代码的高上大版,这个rnn所实现的核心功能就是下面这几行代码所实现的,仅仅是增加了一些更多的配置的支持而已:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```

假设我们输入的序列的长度(时间)为T,那么以此使用Cell运算T次,每次的输出加入到outputs中,并且保留当前的最终state。

这里有两个必选参数和四个可选参数:

cell: 需要为RNN网络提供一个具体的Cell实例(不加修饰的RNNCell或LSTM或GRU或你自己实现的)

inputs:这个就是这个rnn网络的输入了,长度为T的一个ListList的位置表示时间(也就是时间从0到T-1),每个Tensor的形状是[batch_size,input_size],batch_size是你训练的时候,每一个batch的大小,而input_size则是输入元素的维度,这里的两个值要和上面cell里的参数对应。

上面两个参数是必选的,下面几个是可选的:

initial_state:我们知道RNN在运行之初,是需要提前设定一个隐层的状态的,设定就是在这里,一般来说他可以有默认全0的设置。特别注意的是,这里的形状要和cell的结构符合。

dtype:数据类型,不多说

sequence_length:一个[batch_size]的整数数组,里面的第t个元素,表示第t个batch的最大的时间长度是多少。这里的意思是,如果你的这个batch里面,长短不一致,那么可以通过指定batch每个数据的长度,减少运算量,不然batch内所有数据,都会按照时间长度T去运行。如果你有需要这里最好改一下
scppe:VariableScope 不多说

这里,最基本的rnn返回两个东西:

返回 (outputs, state)
outputs : 是一个长度为T的(和输入对应,T代表时间长度)输出,第t个代表,输入了inputs第t个元素后的输出
state :rnn网络处理完输入后,最后剩下的state的状态

3 代码解析

下面我们按照顺序,实际的看下他的代码:

首先呢,就是做一些检查,看你的输入参数、数据是否合法,比如说cell必须是要一个RNNCell的具体实现类(不懂的可以参照第一篇),必须要求输入不为空等。

  if not isinstance(cell, rnn_cell.RNNCell):
    raise TypeError("cell must be an instance of RNNCell")
  if not nest.is_sequence(inputs):
    raise TypeError("inputs must be a sequence")
  if not inputs:
    raise ValueError("inputs must not be empty")

然后就定义了保存输出的地方:

  outputs = []

随后就是建立一个对应的variable_scope:

  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "RNN") as varscope:
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

然后从inputs当中获得第一个输入,从这里面去获取batch_size,input_size等参数,检查一些属性,说实话这段代码我也不是特别想去看,因为我觉得把问题搞复杂了些~~:

 # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

接下来,则是根据可选参数是否提供了,做一些初始化操作,比如initial_state则是会调用cell的那个固有方法去初始化:

  if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      state = cell.zero_state(batch_size, dtype)

这里单独粘贴出来sequence_length参数提供和不提供的区别,提供了大小之后,将会找出包含最大长度、最小长度等之类的参数,帮助后面的计算:

   if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")
      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _state_size_with_prefix(output_size, prefix=[batch_size])
        output = array_ops.zeros(
            array_ops.pack(size), _infer_state_dtype(dtype, state))
        shape = _state_size_with_prefix(
            output_size, prefix=[fixed_batch_size.value])
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(structure=output_size,
                                          flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

随后,在配置好所有的参数后,就应该去实际的跑一下RNN了,这里有个区分,那就是如果指定了长度的 话,调用_rnn_step,不指定长度的话,调用call_cell():

 for time, input_ in enumerate(inputs):
      if time > 0: varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()

      outputs.append(output)

    return (outputs, state)

那么,我们看下如果指定了长度后,这个要怎么算:

def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):

_rnn_step 的用途是计算dynamic rnn当中的某个时间点上的一步,他的返回值是也是(output,state)

猜你喜欢

转载自blog.csdn.net/MebiuW/article/details/62424586