TensorFlow RNN Structures Explained

There are a great many RNN architectures; the paper "An Empirical Exploration of Recurrent Network Architectures" evaluated over ten thousand of them. Here we will only look at the source code of the popular ones: LSTM, GRU, and multi-layer RNNs. All of the source code discussed in this article lives in rnn_cell_impl.py.

You may well ask, "Why read the source code at all?" Indeed, if you only ever call TensorFlow's ready-made tools, you don't need to. But if you want to build something of your own without reinventing the low-level wheels, you should at least know what goes into and comes out of these RNN "wheels".

Before diving into the source code, recall the basic structure of an RNN: at each time step, a cell takes the current input and the previous state, and produces an output and a new state.

In the source code this is literally the interface: an RNN cell's inputs are input and state, and its outputs are output and state. Let's look at what output and state are for each of these structures. We read the structures straight from the source code, which corresponds exactly to the usual RNN equations.

Contents

1 The basic RNN definition

2 LSTM

3 GRU

4 Multi-layer RNN


1 The basic RNN definition

class BasicRNNCell(LayerRNNCell):
  """The most basic RNN cell.

  Note that this cell is not optimized for performance. Please use
  `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.
  """

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel) 
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)  # broadcasting applies the bias across the batch
    output = self._activation(gate_inputs)
    return output, output
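
A minimal usage sketch of this contract (TensorFlow 1.x graph mode; the batch size, input size, and num_units below are made-up values for illustration):

import tensorflow as tf

batch_size, input_size, num_units = 32, 100, 128

cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
inputs = tf.placeholder(tf.float32, [batch_size, input_size])  # one time step
state = cell.zero_state(batch_size, tf.float32)                # [batch, num_units]

output, new_state = cell(inputs, state)
# For BasicRNNCell, output and new_state are the same tensor,
# both of shape [batch, num_units].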

The RNN output can be fed through an MLP layer (with a softmax activation) to obtain a probability for every word in the vocabulary, and a suitable word can then be chosen according to those probabilities.
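
For instance, a sketch of that projection (vocab_size is a hypothetical value, and output is the cell output from the sketch above):

vocab_size = 10000  # hypothetical vocabulary size

logits = tf.layers.dense(output, vocab_size)  # [batch, vocab_size]
probs = tf.nn.softmax(logits)                 # probability of each word
next_word = tf.argmax(probs, axis=-1)         # e.g. greedy decoding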

2 LSTM

In TensorFlow, creating an LSTM cell takes a single line of code

tf.nn.rnn_cell.LSTMCell(num_units)

The core of its implementation in the source code is as follows

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, [batch, state_size]`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = math_ops.matmul(
        array_ops.concat([inputs, m_prev], 1), self._kernel)
    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
    # Diagonal connections
    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
           sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
           self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      m = math_ops.matmul(m, self._proj_kernel)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
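
Setting aside the peepholes, clipping, and projection branches, the code above is exactly the standard LSTM update (\sigma is the sigmoid, \odot is element-wise multiplication, b_f is forget_bias, and tanh is the default _activation):

[i, j, f, o] = W \, [x_t ; h_{t-1}] + b
c_t = \sigma(f + b_f) \odot c_{t-1} + \sigma(i) \odot \tanh(j)
h_t = \sigma(o) \odot \tanh(c_t)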

As you can see, the LSTM's output is h_t and its state is (c_t, h_t). h_t is the hidden state and c_t is the cell state. (In the source code h_t is named m, and when num_proj is set, m is additionally projected down to num_proj dimensions, which is why the docstring says output_dim is num_proj in that case.)
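
A minimal usage sketch over a full sequence (TensorFlow 1.x; the shapes are made-up values):

import tensorflow as tf

batch_size, max_time, input_size, num_units = 32, 20, 100, 128

cell = tf.nn.rnn_cell.LSTMCell(num_units)
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs:     [batch, max_time, num_units] -- h_t at every time step
# final_state: LSTMStateTuple(c=[batch, num_units], h=[batch, num_units])
c_last, h_last = final_state.c, final_state.h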

3 GRU

In TensorFlow, creating a GRU cell takes a single line of code

tf.nn.rnn_cell.GRUCell(num_units)

The core of its implementation in the source code is as follows

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._gate_kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)

    value = math_ops.sigmoid(gate_inputs)
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state

    candidate = math_ops.matmul(
        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
    candidate = nn_ops.bias_add(candidate, self._candidate_bias)

    c = self._activation(candidate)
    new_h = u * state + (1 - u) * c
    return new_h, new_h
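
These few lines are exactly the GRU equations. Note the convention here: the update gate u weights the previous state, so u close to 1 means "keep the old state" (some papers write the gate the other way around):

[r, u] = \sigma(W_g \, [x_t ; h_{t-1}] + b_g)
\tilde{h}_t = \tanh(W_c \, [x_t ; r \odot h_{t-1}] + b_c)
h_t = u \odot h_{t-1} + (1 - u) \odot \tilde{h}_t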

As you can see, the GRU's output is h_t, and its state is also h_t.

4 Multi-layer RNN

Stacking RNN layers in TensorFlow looks like the following, using a multi-layer GRU as an example (here pt is assumed to hold the hyperparameters)

decoder_cell_list = []
for i in range(pt.dec_num_layers):
    # GRUCell requires num_units (the hidden size); pt.dec_num_units is assumed here
    cell = tf.nn.rnn_cell.GRUCell(pt.dec_num_units)
    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=pt.keep_prob)
    decoder_cell_list.append(cell)
decoder_cell = tf.nn.rnn_cell.MultiRNNCell(decoder_cell_list)

The core of MultiRNNCell's implementation in the source code is as follows

  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states

For a multi-layer RNN, output is the last layer's h_t, and state is a tuple containing each layer's state.
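
A sketch of driving the stacked cell over a sequence (decoder_inputs is a hypothetical [batch, max_time, input_size] tensor defined elsewhere):

outputs, final_states = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs, dtype=tf.float32)
# outputs:      [batch, max_time, top-layer num_units] -- the last layer's h_t at each step
# final_states: a tuple with pt.dec_num_layers entries, one GRU state per layer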
