padding_mask内容讲解

首先查看一下Embedding之中的compute_mask函数的定义

def compute_mask(self,inputs,mask=None):
    def compute_mask(self, inputs, mask=None):
        """为了适配T5,保证第一个token不被mask
        """
        if self._mode == 'embedding':
            mask = super(Embedding, self).compute_mask(inputs, mask)
            if mask is not None:
                mask1 = K.ones_like(mask[:, :1], dtype='bool')
                mask2 = mask[:, 1:]
                return K.concatenate([mask1, mask2], 1)
        else:
            return mask

这里定义的

mask = super(Embedding,self).compute_mask(inputs,mask)

然后查看call函数的定义

def call(self,inputs,mode='embedding'):

这里在这个网络层中,只有inputs,没有mask,因为这里是Embedding网络层
虽然这里没有mask,但是这里compute_mask的结果会接着网络层往下传播,实际的mask会在attention网络层之中实现

attention网络层之中实现的mask部分

前面的mask在中间的dense和layernorm网络层继续传播,到后面的attention部分终于发挥作用了
这点mask的机制内容很隐蔽,不容易被看出来
接下来我们查看MultiHeadAttention网络层的调用部分

def build(self,input_shape):
	  super(MultiHeadAttention,self).build(input_shape)
    self.q_dense = Dense(
        units=self.key_size * self.head_nums,
        use_bias=self.use_bias,
        kernel_initializer=self.kernel_initializer
    )
    self.k_dense = Dense(
        units=self.key_size * self.head_nums,
        use_bias=self.use_bias,
        kernel_initializer=self.kernel_initializer
    )
    self.v_dense = Dense(
        units=self.output_dim,
        use_bias=self.use_bias,
        kernel_initializer=self.kernel_initializer
    )

    self.combine_dense = Dense(
        units=self.output_dim,
        use_bias=self.use_bias,
        kernel_initializer=self.kernel_initializer
    )

接着网络层的call函数的调用部分

def call(self, inputs, mask=None, **kwargs):

这里的mask是一个数组,里面的内容为前面四个Dense层传播下来的mask内容
所以接下来取出来的时候

q_mask = K.cast(mask[0],K.floatx())
v_mask = K.cast(mask[2],K.floatx())

本质上mask还是前面的mask,只不过经历了四个Dense,特别的隐蔽
然后这里使用sequence_masking函数进行序列填充的操作

att = sequence_masking(att,v_mask,'add',-1)

查看sequence_masking的函数内容

def sequence_masking(x, mask, mode=0, axis=1):
    '''
    mask shape: [batch_size, seq_length]
    :param x:
    :param mask: 0,1 矩阵
    :param mode: 直接相乘 or 减大数模拟
    :param axis:
    :return:
    '''
    if mask is None or mode not in [0, 1, 'mul', 'add']:
        return x
    if axis == -1:
        axis = K.ndim(x) - 1
    assert axis > 0, 'axis must greater than 0'
    # 形状扩充
    # left
    for _ in range(axis - 1):
        mask = K.expand_dims(mask, 1)
    # right
    for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1):
        mask = K.expand_dims(mask, K.ndim(mask))
    if mode in [0, 'mul']:
        return x * mask
    return x - (1 - mask) * 1e12

猜你喜欢

转载自blog.csdn.net/znevegiveup1/article/details/121248002