- 主要看sequence_layers.py这个脚本中才是实现了attention+decoder的部分,model中只是个架子。
- sequence_layer中也是直接调用了TF的api,如果想深入理解,还是需要看TF源码。先从sequence_layer入手。
AttentionWithAutoregression继承了Attention,Attention继承了SequenceLayerBase。其中create_logits调用的就是SequenceLayerBase中的,create_logits中的unroll_cell调用的是Attention里面的。
def create_logits(self):
"""Creates character sequence logits for a net specified in the constructor.
A "main" method for the sequence layer which glues together all pieces.
Returns:
A tensor with shape [batch_size, seq_length, num_char_classes].
"""
with tf.variable_scope('LSTM'):
first_label = self.get_input(prev=None, i=0)
decoder_inputs = [first_label] + [None] * (self._params.seq_length - 1)
# 是全0矩阵吗?
lstm_cell = tf.contrib.rnn.LSTMCell(
self._mparams.num_lstm_units,
# 就是输出的size
use_peepholes=False,
# 采用的是最早提出的LSTM的构造,1997年
cell_clip=self._mparams.lstm_state_clip_value,
state_is_tuple=True,
initializer=orthogonal_initializer)
# 构建了decoder的LSTM结构
lstm_outputs, _ = self.unroll_cell(
decoder_inputs=decoder_inputs,
initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
loop_function=self.get_input,
cell=lstm_cell)
# 调用了ATTENTION方法中的
with tf.variable_scope('logits'):
logits_list = [
tf.expand_dims(self.char_logit(logit, i), dim=1)
for i, logit in enumerate(lstm_outputs)
]
return tf.concat(logits_list, 1)
class ATTENTION():
def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
return tf.contrib.legacy_seq2seq.attention_decoder(
decoder_inputs=decoder_inputs,
initial_state=initial_state,
# 这两个参数的意义还咩有理解透彻
attention_states=self._net,
# _net中是CNN输出后和空间one-hot特征concat的特征
cell=cell,
loop_function=self.get_input)
https://github.com/tensorflow/models/blob/master/research/attention_ocr/python/sequence_layers.py
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------
tf.contrib.legacy_seq2seq.attention_decoder源码解读:
如果是想理解attention的机制以及实现,还是需要读到TF源码中的,毕竟到这一部分的代码还是没有涉及到真正的模型实现。
我这里看源码是TF1.10的,这是因为attention_ocr源码中用的就是这个。其实是可以看更新一点的源码的。