attention_ocr源码

  1. 主要看sequence_layers.py这个脚本中才是实现了attention+decoder的部分,model中只是个架子。
  2. sequence_layer中也是直接调用了TF的api,如果想深入理解,还是需要看TF源码。先从sequence_layer入手。

AttentionWithAutoregression继承了Attention,Attention继承了SequenceLayerBase。其中create_logits调用的就是SequenceLayerBase中的,create_logits中的unroll_cell调用的是Attention里面的。

def create_logits(self):
    """Creates character sequence logits for a net specified in the constructor.
    A "main" method for the sequence layer which glues together all pieces.
    Returns:
      A tensor with shape [batch_size, seq_length, num_char_classes].
    """
    with tf.variable_scope('LSTM'):
      first_label = self.get_input(prev=None, i=0)
      decoder_inputs = [first_label] + [None] * (self._params.seq_length - 1)
        # 是全0矩阵吗?
      lstm_cell = tf.contrib.rnn.LSTMCell(
          self._mparams.num_lstm_units,
            # 就是输出的size
          use_peepholes=False,
            # 采用的是最早提出的LSTM的构造,1997年
          cell_clip=self._mparams.lstm_state_clip_value,
          state_is_tuple=True,
          initializer=orthogonal_initializer)
        # 构建了decoder的LSTM结构
      lstm_outputs, _ = self.unroll_cell(
          decoder_inputs=decoder_inputs,
          initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
          loop_function=self.get_input,
          cell=lstm_cell)
        # 调用了ATTENTION方法中的
    with tf.variable_scope('logits'):
      logits_list = [
          tf.expand_dims(self.char_logit(logit, i), dim=1)
          for i, logit in enumerate(lstm_outputs)
      ]

    return tf.concat(logits_list, 1)


class ATTENTION():

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    return tf.contrib.legacy_seq2seq.attention_decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        # 这两个参数的意义还咩有理解透彻
        attention_states=self._net,
        # _net中是CNN输出后和空间one-hot特征concat的特征
        cell=cell,
        loop_function=self.get_input)

https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/sequence_layers.py

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------

tf.contrib.legacy_seq2seq.attention_decoder源码解读:

如果是想理解attention的机制以及实现,还是需要读到TF源码中的,毕竟到这一部分的代码还是没有涉及到真正的模型实现。

我这里看源码是TF1.10的,这是因为attention_ocr源码中用的就是这个。其实是可以看更新一点的源码的。

发布了45 篇原创文章 · 获赞 1 · 访问量 8571

猜你喜欢

转载自blog.csdn.net/qq_32110859/article/details/100514625
OCR
今日推荐