tensorflow中关于BahdanauAttention以及LuongAttention实现细节

背景介绍

在 TensorFlow 中,Attention 的相关实现代码是在 tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 文件中,这里面实现了两种 Attention 机制,分别是 BahdanauAttention 和 LuongAttention,其实现论文分别如下:
Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau, et al
Effective Approaches to Attention-based Neural Machine Translation, Luong, et al
整个 attention_wrapper.py 文件中主要包含几个类,我们主要关注其中几个:

AttentionMechanism、_BaseAttentionMechanism、LuongAttention、BahdanauAttention 实现了 Attention 机制的逻辑。
AttentionMechanism 是 Attention 类的父类,继承了 object 类,内部没有任何实现。
_BaseAttentionMechanism 继承自 AttentionMechanism 类,定义了 Attention 机制的一些公共方法实现和属性。
LuongAttention、BahdanauAttention 均继承 _BaseAttentionMechanism 类,分别实现了上面两篇论文的 Attention 机制。
AttentionWrapperState 用来存储整个计算过程中的 state,和 RNN 中的 state 类似,只不过这里额外还存储了 attention、time 等信息。
AttentionWrapper 主要用于对封装 RNNCell,继承自 RNNCell,封装后依然是 RNNCell 的实例,可以构建一个带有 Attention 机制的 Decoder。
另外还有一些公共方法,例如 hardmax、safe_cumpord 等。
下面我们以 BahdanauAttention 为例来说明 Attention 机制及 AttentionWrapper 的实现。

1.BahdanauAttention介绍

BahdanauAttention类,首先看init函数:

    def __init__(self,
        num_units,
        memory,
        memory_sequence_length=None,
        normalize=False,
        probability_fn=None,
        score_mask_value=None,
        dtype=None,
        name="BahdanauAttention"):
  • num_units:神经元节点数,我们知道在计算 eij 的时候,需要使用 si−1 和 hj 来进行计算,而二者的维度可能并不是统一的,需要进行变换和统一,所以这里就有了 Wa 和 Ua 这两个系数,所以在代码中就是用 num_units 来声明了一个全连接 Dense 网络,用于统一二者的维度,以便于下一步的计算:

    query_layer=layers_core.Dense(
        num_units, name="query_layer", use_bias=False, dtype=dtype)
    memory_layer=layers_core.Dense(
        num_units, name="memory_layer", use_bias=False, dtype=dtype)
    
  • memory:The memory to query,一般为RNN encoder的输出。维度为[batch_size, max_time, context_dim]。在父类_BaseAttentionMechanism的初始化方法中,

    with ops.name_scope(
        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
        self._values = _prepare_memory(
        memory, memory_sequence_length,
        check_inner_dims_defined=check_inner_dims_defined)
        self._keys = (
        self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable
        else self._values)
    

首先是使用_prepare_memory函数对memory进行处理,然后使用上面定义的memory_layer对memory进行全连接的维度变换,变换成[batch_size, max_time, num_units]

  • memory_sequence_length:Sequence lengths for the batch entries in memory. 即 memory 变量的长度信息,类似于 dynamic_rnn 中的 sequence_length,被 _prepare_memory() 方法调用处理 memory 变量,进行 mask 操作:

    seq_len_mask = array_ops.sequence_mask(
        memory_sequence_length,
        maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
        dtype=nest.flatten(memory)[0].dtype)
    seq_len_batch_size = (
        memory_sequence_length.shape[0].value
        or array_ops.shape(memory_sequence_length)[0])
    
  • normalize:Whether to normalize the energy term. 即是否要实现标准化,方法出自论文:Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, Salimans, et al。

  • probability_fn:A callable function which converts the score to probabilities. 计算概率时的函数,必须是一个可调用的函数,默认使用 softmax(),还可以指定 hardmax() 等函数。
  • score_mask_value:The mask value for score before passing into probability_fn. The default is -inf. Only used if memory_sequence_length is not None. 在使用 probability_fn 计算概率之前,对 score 预先进行 mask 使用的值,默认是负无穷。但这个只有在 memory_sequence_length 参数定义的时候有效。
  • dtype:The data type for the query and memory layers of the attention mechanism. 数据类型,默认是 float32。
  • name:Name to use when creating ops,自定义名称。

然后看call()函数:

    def __call__(self, query, state):
           with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
                processed_query = self.query_layer(query) if self.query_layer else query
                score = _bahdanau_score(processed_query, self._keys, self._normalize)
                alignments = self._probability_fn(score, state)
                next_state = alignments
                return alignments, next_state

call函数首先对query进行全连接层的维度变换,然后调用_bahdanau_score函数计算score,也就是eij,然后调用_probability_fn函数计算softmax.

  • 在_bahdanau_score函数中,_key函数表示Encoder的输出,也即是memory的变换后的值。procesed_query值为decoder 隐藏层。_bahdanau_score函数部分代码如下所示:

    if normalize:
           # Scalar used in weight normalization
           g = variable_scope.get_variable(
           "attention_g", dtype=dtype,
            initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))),
                                                                           shape=())
            # Bias added prior to the nonlinearity
            b = variable_scope.get_variable("attention_b", [num_units], dtype=dtype,
                                                                 initializer=init_ops.zeros_initializer())
            # normed_v = g * v / ||v||
             normed_v = g * v * math_ops.rsqrt(math_ops.reduce_sum(math_ops.square(v)))
             return math_ops.reduce_sum(normed_v * math_ops.tanh(keys + processed_query + b), [2])
    else:
             return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
    

从代码中可以看出,_bahdanau_score函数主要有两个作用,一个是计算eij,另一个是对eij进行weighted normalization处理。
这里写图片描述

score计算的方式有点类似concat的方式。

  • _probability_fn函数如果不直接指定的话,默认的值为softmax函数。

2.LuongAttention介绍

与BahdanauAttention相比,LuongAttention在具体实现上相似,只是在代码细节上略有不同。下面进行详细的介绍:

  • 首先,在init函数中,只是简单的定义了memory_layer,代码如下所下所示:

         super(LuongAttention, self).__init__(
                query_layer=None,
                memory_layer=layers_core.Dense(
                num_units, name="memory_layer", use_bias=False, dtype=dtype),
                memory=memory,
                probability_fn=wrapped_probability_fn,
                memory_sequence_length=memory_sequence_length,
                score_mask_value=score_mask_value,
                name=name)
    
  • 其次,在call函数中,结构相似,主要区别是将socre函数变成了_luong_score函数。

  • 最后,在_luong_score函数中,主要代码如下:

    score = math_ops.matmul(query, keys, transpose_b=True)
    score = array_ops.squeeze(score, [1])
    
    if scale:
            # Scalar used in weight scaling
            g = variable_scope.get_variable(
                  "attention_g", dtype=dtype,
            initializer=init_ops.ones_initializer, shape=())
            score = g * score
    

这里实现的是简单的相乘的方式。不过需要注意的一点是,在attention的父类_BaseAttentionMechanism中,已经对self._values值进行dense处理,处理后的结果就是key。

相关链接:https://cuiqingcai.com/5873.html

发布了98 篇原创文章 · 获赞 337 · 访问量 48万+

猜你喜欢

转载自blog.csdn.net/yiyele/article/details/81393229
今日推荐