Tensorflow 动态RNN源码 初探

RNN在深度学习中占据重要地位,我们常常调用tensorflow的包就可以完成RNN的构建与训练,但通用的RNN并不总是能满足我们的需求,若要改动,必先知其细。下面我们根据源码对RNN的实现一探究竟。在探究之前,先来说一下什么叫动态RNN,我们都知道RNN全名是循环神经网络,循环嘛,自然是动态的,通过循环的方式动态生成一个个的token,直至整句话生成完毕时停止。

目录

1 tensorflow 版本

2 动态RNN实现“三板斧”


1 tensorflow 版本

import tensorflow as tf
tf.__version__   # tensorflow版本为1.12.0

2 动态RNN实现“三板斧”

如果定制自己需要的动态RNN,只需要修改三板斧中的对应函数,即可将自己的想法融入tf框架中,无需自己从0实现一个动态RNN,原因有二,一是方便,二是自己从0实现的不一定比tf的写的好emm

第一板斧负责决定当前时间步的输出(sample函数)和下一时间步的输入(next_inputs函数)。

helper = tf.contrib.seq2seq.TrainingHelper(inputs=input_embed,…)  # input_embed是rnn输入字符的embedding

上述函数在文件helper.py中,是专用于训练时候的helper,除此之外,helper.py中还有适用于inference时候的helper,一起来看看源码(下面为源码的重要部分截取,不是完整的helper.py文件,下同),关键是sample函数和next_inputs函数的实现。

# helper.py中所有的class,除了用于训练的TrainingHelper,
# 还有一些用于推断时候的helper,甚至可以自定义,即CustomHelper。
# 对于每个helper,关键在于sample函数和next_inputs函数的实现。

__all__ = [
    "Helper",
    "TrainingHelper",
    "GreedyEmbeddingHelper",
    "SampleEmbeddingHelper",
    "CustomHelper",
    "ScheduledEmbeddingTrainingHelper",
    "ScheduledOutputTrainingHelper",
    "InferenceHelper",
]

#训练阶段 以TrainingHelper为例进行分析
class TrainingHelper(Helper):

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    initial部分的源码不进行粘贴

  def sample(self, time, outputs, name=None, **unused_kwargs):
    # 采样得到当前时间步的输出token
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)  # 取概率最大的token作为输出
      return sample_ids

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      # 若rnn未finished,则取当前时间步输出真值作为下一步的输入。因为训练阶段是有标签的
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)

# 推断阶段的helper以GreedyEmbeddingHelper为例进行分析
class GreedyEmbeddingHelper(Helper):
  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    
    # 因为是推断阶段,所以把当前时间步的输出的预测值作为下一步的输入。
    # sample_ids是token id,所以用_embedding_fn函数得到其embedding后再作为next_inputs
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)

第二板斧负责执行一个时间步(step函数),调用cell得到该时间步的输出概率,调用helper得到该时间步的输出token id和下一步的输入token的embedding。

decoder = tf.contrib.seq2seq.BasicDecoder(cell=rnn_cell, helper=helper,…)

上述函数在basic_decoder.py中,BasicDecoder类继承于Decoder类(Decoder类在decoder.py文件中,和dynamic_decode函数在一个文件中),实现了Decoder类中的step函数。

其他的Decoder比如BeamSearchDecoder也继承于Decoder类,实现了Decoder类中的step函数。

所以,如果想自己实现一个decoder的话,继承Decoder类并实现step函数即可。

扫描二维码关注公众号,回复: 9571084 查看本文章
  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      cell_outputs, cell_state = self._cell(inputs, state)
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)

可以看出,step函数中调用了cell来得到当前时间步的输出,这里的cell是rnn_cell,定义了RNN的结构,所以了解cell的输入与输出是什么很重要,这样才能正确调用。如果想了解常用rnn_cell的结构,可以阅读Tensorflow RNN结构 解读

第三板斧负责模拟RNN在每个时间步的情况,并在合适的时刻(比如遇到eos或者达到指定的最大长度)停止。

final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,…)

上述函数位于decoder.py文件中,dynamic_decode是一个loop(循环)来得到全部时间步的情况,每个时间步都调用decoder。

"""
condition是判断是否停止的条件,body中会调用decoder.step()来得到相关信息。

loop_vars是在循环中不断变化更新的变量,这些变量需要输入到body函数中,
在body函数中计算更新并return,以作为下一个循环body函数的输入。

这里res的内容其实就是body函数返回的内容,也就是loop_vars的值。
"""
res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

所以,如果想定制自己需要的动态RNN,要想清楚loop_vars有哪些,然后写到loop_vars中哦~

发布了34 篇原创文章 · 获赞 20 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/G_B_L/article/details/104003862