Code Reading: The Official TensorFlow BeamSearch

Official code

0. BasicSeq2Seq

Let's start from the entry point. The BasicSeq2Seq class inherits from Seq2SeqModel; below is the part that handles decoding. Note that training and inference use different decoding paths.

  @templatemethod("decode")
  def decode(self, encoder_output, features, labels):
    decoder = self._create_decoder(encoder_output, features, labels)
    if self.use_beam_search:
      decoder = self._get_beam_search_decoder(decoder)

    bridge = self._create_bridge(
        encoder_outputs=encoder_output,
        decoder_state_size=decoder.cell.state_size)
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
      return self._decode_infer(decoder, bridge, encoder_output, features,
                                labels)
    else:
      return self._decode_train(decoder, bridge, encoder_output, features,
                                labels)

With this function understood, we will continue in two directions. One is of course the BeamSearchDecoder that this post is about, returned by _get_beam_search_decoder. The other is the bridge: this variable never appears in the paper, so let's first figure out what it is.

1. The Bridge class

I found this in the code; it does not appear in the paper.

A bridge defines how information is passed between the encoder and the decoder, and the library provides several bridge implementations to connect the two.

For example, the encoder may emit a $[batch, m]$ vector $V_e$, while the decoder expects a $[batch, n]$ input vector $V_d$, where $m$ and $n$ can differ. The bridge's job is to convert $V_e$ into $V_d$ according to some rule.
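As a minimal sketch of the idea (not the library's code; the sizes and names below are made up for illustration), mapping a $[batch, m]$ encoder vector to a $[batch, n]$ decoder vector takes just one dense layer:

import tensorflow as tf

# Hypothetical sizes: encoder state m=512, decoder state n=256.
batch_size, m, n = 4, 512, 256
v_e = tf.zeros([batch_size, m])  # stand-in for the encoder output V_e

# A single fully connected layer maps V_e to a decoder-sized V_d.
v_d = tf.contrib.layers.fully_connected(
    inputs=v_e, num_outputs=n, activation_fn=tf.tanh)
# v_d now has shape [batch_size, n] and could serve as the decoder's
# initial state.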

Let's look at the base class:

# Imports used by the Bridge excerpts below (as in the seq2seq library)
import abc
from pydoc import locate

import six
import tensorflow as tf
from tensorflow.python.util import nest

from seq2seq.configurable import Configurable


@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
  """An abstract class defining how information is passed from the
  encoder to the decoder.

  Args:
    encoder_outputs: A namedtuple that corresponds to the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    Configurable.__init__(self, params, mode)
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    self.batch_size = tf.shape(
        nest.flatten(self.encoder_outputs.final_state)[0])[0]

  def __call__(self):
    """Runs the bridge function.
    Returns:
      An initial decoder_state tensor or tuple of tensors.
    """
    return self._create()

  @abc.abstractmethod
  def _create(self):
    """ Implements the logic for this bridge.
    This function should be implemented by child classes.
    Returns:
      A tuple initial_decoder_state tensor or tuple of tensors.
    """
    raise NotImplementedError("Must be implemented by child class")

All the logic lives in the _create function, whose concrete implementation is left to subclasses; it returns the decoder's initial state.

Bridge has three subclasses: ZeroBridge, PassThroughBridge, and InitialStateBridge.


1.1 ZeroBridge

Passes no information between encoder and decoder; the decoder's initial state is set to 0.

class ZeroBridge(Bridge):
  """A bridge that does not pass any information between encoder and decoder
  and sets the initial decoder state to 0. The input function is not modified.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    zero_state = nest.map_structure(
        lambda x: tf.zeros([self.batch_size, x], dtype=tf.float32),
        self.decoder_state_size)
    return zero_state
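As a quick check of what this produces, here's a minimal sketch for an LSTM decoder (the cell and batch size are assumptions for illustration):

import tensorflow as tf
from tensorflow.python.util import nest

cell = tf.contrib.rnn.LSTMCell(256)   # assumed decoder cell
batch_size = 4
# cell.state_size is LSTMStateTuple(c=256, h=256), so map_structure
# builds an LSTMStateTuple of two [4, 256] zero tensors -- the same
# thing cell.zero_state(batch_size, tf.float32) would return.
zero_state = nest.map_structure(
    lambda size: tf.zeros([batch_size, size], dtype=tf.float32),
    cell.state_size)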

1.2 PassThroughBridge

This bridge can be used if and only if the encoder and decoder have the same state size (e.g. they use the same RNN cell), i.e. $m = n$. The encoder's final state is then fed to the decoder as-is.

class PassThroughBridge(Bridge):
  """Passes the encoder state through to the decoder as-is. This bridge
  can only be used if encoder and decoder have the exact same state size, i.e.
  use the same RNN cell.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    nest.assert_same_structure(self.encoder_outputs.final_state,
                               self.decoder_state_size)
    return self.encoder_outputs.final_state

1.3 InitialStateBridge

There is no problem that adding another layer can't solve~ So when $m \neq n$, a fully connected (FC) layer maps $V_e$ to $V_d$.

This looks like the most commonly used one, and judging from the code, this is indeed the bridge that gets used.

class InitialStateBridge(Bridge):
  """A bridge that creates an initial decoder state based on the output
  of the encoder. This state is created by passing the encoder outputs
  through an additional layer to match them to the decoder state size.
  The input function remains unmodified.

  Args:
    encoder_outputs: A namedtuple that corresponds to the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
    bridge_input: Which attribute of the `encoder_outputs` to use for the
      initial state calculation. For example, "final_state" means that
      `encoder_outputs.final_state` will be used.
    activation_fn: An optional activation function for the extra
      layer inserted between encoder and decoder. A string for a function
      name contained in `tf.nn`, e.g. "tanh".
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    super(InitialStateBridge, self).__init__(encoder_outputs,
                                             decoder_state_size, params, mode)

    if not hasattr(encoder_outputs, self.params["bridge_input"]):
      raise ValueError("Invalid bridge_input not in encoder outputs.")

    self._bridge_input = getattr(encoder_outputs, self.params["bridge_input"])
    self._activation_fn = locate(self.params["activation_fn"])

  @staticmethod
  def default_params():
    return {
        "bridge_input": "final_state",
        "activation_fn": "tensorflow.identity",
    }

  def _create(self):
    # Concat bridge inputs on the depth dimensions
    bridge_input = nest.map_structure(
        lambda x: tf.reshape(x, [self.batch_size, _total_tensor_depth(x)]),
        self._bridge_input)
    bridge_input_flat = nest.flatten([bridge_input])
    bridge_input_concat = tf.concat(bridge_input_flat, 1)

    state_size_splits = nest.flatten(self.decoder_state_size)
    total_decoder_state_size = sum(state_size_splits)

    # Pass bridge inputs through a fully connected layer
    initial_state_flat = tf.contrib.layers.fully_connected(
        inputs=bridge_input_concat,
        num_outputs=total_decoder_state_size,
        activation_fn=self._activation_fn)

    # Shape back into required state size
    initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, initial_state)
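To make the split-and-pack logic concrete, here is a toy walkthrough of _create with assumed sizes: a [batch, 512] bridge input feeding an LSTM decoder that expects LSTMStateTuple(c=256, h=256):

import tensorflow as tf
from tensorflow.python.util import nest

batch_size = 4
bridge_input_concat = tf.zeros([batch_size, 512])   # assumed encoder depth

decoder_state_size = tf.contrib.rnn.LSTMStateTuple(c=256, h=256)
state_size_splits = nest.flatten(decoder_state_size)   # [256, 256]
total_decoder_state_size = sum(state_size_splits)      # 512

# One FC layer produces the entire decoder state at once: [batch, 512]
initial_state_flat = tf.contrib.layers.fully_connected(
    inputs=bridge_input_concat,
    num_outputs=total_decoder_state_size,
    activation_fn=tf.identity)

# Split into two [batch, 256] tensors and pack them back into the
# structure the decoder expects (an LSTMStateTuple).
initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
initial_state = nest.pack_sequence_as(decoder_state_size, initial_state)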


2. The BeamSearchDecoder class

Besides the beam search decoder we cover here, there is also an attention decoder; both of these grow out of the basic decoder.

A decoder that uses beam search. Can only be used for inference, not training.

If beam search is used for decoding, the decoder must be fed a batch size of 1, which yields an effective batch size of `beam_width`.

class BeamSearchDecoder(RNNDecoder):
  """The BeamSearchDecoder wraps another decoder to perform beam search instead
  of greedy selection. This decoder must be used with batch size of 1, which
  will result in an effective batch size of `beam_width`.
  
  """

  def __init__(self, decoder, config):
    """
    Args:
      decoder: An `RNNDecoder` instance, i.e. an RNN cell wrapped with
        decoding logic.
      config: A configuration object holding the beam search parameters.
    """
    super(BeamSearchDecoder, self).__init__(decoder.params, decoder.mode,
                                            decoder.name)
    self.decoder = decoder
    self.config = config
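The excerpt does not show what config contains. As a hypothetical stand-in (field names are my assumption, gathered from what beam_search_step reads below; the library's own config object may differ), it is essentially:

from collections import namedtuple

# Hypothetical config: these fields are exactly what the beam search
# code below references (beam_width, vocab_size, eos_token,
# choose_successors_fn), plus a length penalty weight for hyp_score.
BeamSearchConfig = namedtuple(
    "BeamSearchConfig",
    ["beam_width", "vocab_size", "eos_token",
     "length_penalty_weight", "choose_successors_fn"])

config = BeamSearchConfig(
    beam_width=5,
    vocab_size=30000,
    eos_token=2,                     # id of the end-of-sequence token
    length_penalty_weight=0.6,
    choose_successors_fn=None)       # in practice: choose_top_k (see 2.1)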

Now let's look at what each step of the BeamSearchDecoder does.

First, call the wrapped decoder's own step to get its output and state:

(decoder_output, decoder_state, _, _) = \
        self.decoder.step(time_, inputs,  decoder_state)

Next, run one step of beam search, which returns that step's beam search output and state:

bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

Here time_ is the time step, counted from 0; at step 0 all beams are considered identical. logits is a [beam_width, vocab_size] tensor with the current step's logits; beam_state is the current state, a BeamSearchState instance; and config holds the relevant parameters.

2.1 Inside the step

Let's dig into this function:

def beam_search_step(time_, logits, beam_state, config):
    """Performs a single step of beam search.

    Args:
        See the notes below the code.
    Returns:
        A (BeamSearchStepOutput, BeamSearchState) pair for this step.
    """
    # Current lengths and finished flags of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Total log probability of each new hypothesis: [beam_width, vocab_size]
    probs = tf.nn.log_softmax(logits)
    ## Mask the branches that have already finished so they stop growing
    probs = mask_probs(probs, config.eos_token, previously_finished)
    total_probs = tf.expand_dims(beam_state.log_probs, 1) + probs

    # Continuation lengths (word counts): add 1 to every continuation
    # that is not EOS and whose beam had not finished before
    lengths_to_add = tf.one_hot([config.eos_token] * config.beam_width,
                              config.vocab_size, 0, 1)
    add_mask = (1 - tf.to_int32(previously_finished))
    lengths_to_add = tf.expand_dims(add_mask, 1) * lengths_to_add
    new_prediction_lengths = tf.expand_dims(prediction_lengths,
                                          1) + lengths_to_add

    # Score every candidate hypothesis
    scores = hyp_score(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      config=config)

    scores_flat = tf.reshape(scores, [-1])
    # At the first time step, only the initial beam is considered
    scores_flat = tf.cond(
      tf.convert_to_tensor(time_) > 0, lambda: scores_flat, lambda: scores[0])

    # Find the next beams via the configured successors function;
    # see the notes below the code
    next_beam_scores, word_indices = \
            config.choose_successors_fn(scores_flat, config)

    next_beam_scores.set_shape([config.beam_width])
    word_indices.set_shape([config.beam_width])

    # Gather the probabilities, beam ids, and states of the chosen predictions
    total_probs_flat = tf.reshape(total_probs, [-1], name="total_probs_flat")
    next_beam_probs = tf.gather(total_probs_flat, word_indices)
    next_beam_probs.set_shape([config.beam_width])
    next_word_ids = tf.mod(word_indices, config.vocab_size)
    next_beam_ids = tf.div(word_indices, config.vocab_size)

    # A beam is finished if its parent beam was already finished
    # or if it just emitted the end token
    next_finished = tf.logical_or(
      tf.gather(beam_state.finished, next_beam_ids),
      tf.equal(next_word_ids, config.eos_token))

    # Lengths of the beams for the next prediction:
    # 1. beams that had already finished do not grow
    # 2. beams whose current prediction is the end token do not grow
    # 3. beams that are still alive grow by 1
    lengths_to_add = tf.to_int32(tf.not_equal(next_word_ids, config.eos_token))
    lengths_to_add = (1 - tf.to_int32(next_finished)) * lengths_to_add
    next_prediction_len = tf.gather(beam_state.lengths, next_beam_ids)
    next_prediction_len += lengths_to_add

    next_state = BeamSearchState(
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

    output = BeamSearchStepOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      beam_parent_ids=next_beam_ids)

    return output, next_state
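One helper that is not shown above is mask_probs. Here is a sketch of the behavior the step relies on (my reconstruction, not the library's exact code): a finished beam gets log probability 0 for the EOS token and effectively $-\infty$ everywhere else, so it can only "extend" with EOS and its accumulated score stops changing:

import tensorflow as tf

def mask_probs(probs, eos_token, finished):
  # probs: [beam_width, vocab_size] log probabilities
  # finished: [beam_width] bool, True for beams that already ended
  vocab_size = tf.shape(probs)[1]
  # Row for finished beams: 0 at eos_token, a very negative value elsewhere
  finished_row = tf.one_hot(eos_token, vocab_size,
                            on_value=0., off_value=tf.float32.min)
  finished_mask = tf.expand_dims(tf.to_float(finished), 1)
  return finished_mask * finished_row + (1. - finished_mask) * probs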

Now for the inputs:

  • logits is simply the current step's logits;
  • beam_state (its definition was linked in the original post) contains three fields: "log_probs" (each beam's accumulated log probability so far, i.e. which words may come next), "finished" (whether a beam has ended, e.g. it reached the maximum length or hit the end token), and "lengths" (each beam's length, i.e. how many words it contains so far); a sketch of these structures follows this list;
  • config holds the relevant parameters.
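For reference, a minimal sketch of the two structures, with fields matching exactly how they are constructed in beam_search_step above (the definition site in the library differs):

from collections import namedtuple

# State carried between beam search steps
BeamSearchState = namedtuple(
    "BeamSearchState", ["log_probs", "lengths", "finished"])

# Output emitted by each beam search step
BeamSearchStepOutput = namedtuple(
    "BeamSearchStepOutput", ["scores", "predicted_ids", "beam_parent_ids"])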

Next, hyp_score. This function adds a length penalty, an idea taken from the 2016 paper on Google's NMT system. The reasoning is simple: every score we accumulate is a negative log probability, and we want to maximize the total score, so an unnormalized search favors generating shorter sentences containing fewer words. That is clearly not what we want, so a length penalty exponent $\alpha \in (0, 1)$ is introduced to normalize for the length of the generated sentence. $\alpha$ can be tuned on a validation set; the best value usually lies in $[0.6, 0.7]$:

$$lp(Y) = \frac{(5 + |Y|)^{\alpha}}{(5 + 1)^{\alpha}}$$
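A minimal sketch of the scoring with this penalty (my reconstruction; in the library the weight comes from config, whose exact field name is an assumption here):

import tensorflow as tf

def length_penalty(sequence_lengths, penalty_factor):
  # GNMT length penalty: ((5 + |Y|) / (5 + 1)) ** alpha
  return tf.div(
      (5. + tf.to_float(sequence_lengths)) ** penalty_factor,
      (5. + 1.) ** penalty_factor)

def hyp_score(log_probs, sequence_lengths, config):
  # Divide the accumulated log probability by the length penalty,
  # assuming config carries alpha as length_penalty_weight.
  return log_probs / length_penalty(
      sequence_lengths, config.length_penalty_weight)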

The definition of choose_successors_fn and the code that sets it were linked in the original post; in practice choose_top_k is used to find the next beams. Here is the function:

def choose_top_k(scores_flat, config):
  """Chooses the top-k beams as successors.
  """
  next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=config.beam_width)
  return next_beam_scores, word_indices
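Note that top-k runs over the flattened [beam_width * vocab_size] scores, so each returned index encodes both a parent beam and a word. A toy example (numbers assumed):

import tensorflow as tf

# Assume beam_width=3 and vocab_size=5, so scores_flat has 15 entries.
# A flat index of 7 falls in row 7 // 5 = 1 and column 7 % 5 = 2:
word_indices = tf.constant([7])
next_beam_ids = tf.div(word_indices, 5)   # -> [1]: continue old beam 1
next_word_ids = tf.mod(word_indices, 5)   # -> [2]: by emitting word id 2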

2.2 After the step

Next, everything is reordered according to the beam search result: the decoder state and outputs are gathered with beam_parent_ids so that each surviving beam carries its parent's state, and the results are then packaged as the step's output.
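A toy illustration of that reordering (numbers assumed): with three beams and beam_parent_ids = [0, 0, 2], old beam 0 survives twice and old beam 1 is dropped:

import tensorflow as tf

old_states = tf.constant([[1.0], [2.0], [3.0]])   # one row per old beam
beam_parent_ids = tf.constant([0, 0, 2])
new_states = tf.gather(old_states, beam_parent_ids)
# new_states == [[1.0], [1.0], [3.0]] -- beam 0's state is duplicated,
# beam 1's state is discarded.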

2.3 The complete step function

  def step(self, time_, inputs, state, name=None):
    decoder_state, beam_state = state

    # Call the original decoder
    (decoder_output, decoder_state, _, _) = self.decoder.step(time_, inputs,
                                                              decoder_state)

    # Perform a step of beam search
    bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

    # Shuffle everything according to beam search result
    decoder_state = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_state)
    decoder_output = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_output)

    next_state = (decoder_state, beam_state)

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids,
        original_outputs=decoder_output)

    finished, next_inputs, next_state = self.decoder.helper.next_inputs(
        time=time_,
        outputs=decoder_output,
        state=next_state,
        sample_ids=bs_output.predicted_ids)
    next_inputs.set_shape([self.batch_size, None])

    return (outputs, next_state, next_inputs, finished)

3. Summary

Beam search feels a bit like BFS with an added constraint, where the width limit is beam_size.
The code also teaches a lot about implementation details, such as what to do when a beam finishes early during inference, and small touches like the bridge. A rewarding read!



Reprinted from blog.csdn.net/u012328476/article/details/104147178