padding_mask的内容讲解
首先查看一下Embedding之中的compute_mask函数的定义
def compute_mask(self,inputs,mask=None):
def compute_mask(self, inputs, mask=None):
"""为了适配T5,保证第一个token不被mask
"""
if self._mode == 'embedding':
mask = super(Embedding, self).compute_mask(inputs, mask)
if mask is not None:
mask1 = K.ones_like(mask[:, :1], dtype='bool')
mask2 = mask[:, 1:]
return K.concatenate([mask1, mask2], 1)
else:
return mask
这里定义的
mask = super(Embedding,self).compute_mask(inputs,mask)
然后查看call函数的定义
def call(self,inputs,mode='embedding'):
这里在这个网络层中,只有inputs,没有mask,因为这里是Embedding网络层
虽然这里没有mask,但是这里compute_mask的结果会接着网络层往下传播,实际的mask会在attention网络层之中实现
attention网络层之中实现的mask部分
前面的mask在中间的dense和layernorm网络层继续传播,到后面的attention部分终于发挥作用了
这点mask的机制内容很隐蔽,不容易被看出来
接下来我们查看MultiHeadAttention网络层的调用部分
def build(self,input_shape):
super(MultiHeadAttention,self).build(input_shape)
self.q_dense = Dense(
units=self.key_size * self.head_nums,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.k_dense = Dense(
units=self.key_size * self.head_nums,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.v_dense = Dense(
units=self.output_dim,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
self.combine_dense = Dense(
units=self.output_dim,
use_bias=self.use_bias,
kernel_initializer=self.kernel_initializer
)
接着网络层的call函数的调用部分
def call(self, inputs, mask=None, **kwargs):
这里的mask是一个数组,里面的内容为前面四个Dense层传播下来的mask内容
所以接下来取出来的时候
q_mask = K.cast(mask[0],K.floatx())
v_mask = K.cast(mask[2],K.floatx())
本质上mask还是前面的mask,只不过经历了四个Dense,特别的隐蔽
然后这里使用sequence_masking函数进行序列填充的操作
att = sequence_masking(att,v_mask,'add',-1)
查看sequence_masking的函数内容
def sequence_masking(x, mask, mode=0, axis=1):
'''
mask shape: [batch_size, seq_length]
:param x:
:param mask: 0,1 矩阵
:param mode: 直接相乘 or 减大数模拟
:param axis:
:return:
'''
if mask is None or mode not in [0, 1, 'mul', 'add']:
return x
if axis == -1:
axis = K.ndim(x) - 1
assert axis > 0, 'axis must greater than 0'
# 形状扩充
# left
for _ in range(axis - 1):
mask = K.expand_dims(mask, 1)
# right
for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1):
mask = K.expand_dims(mask, K.ndim(mask))
if mode in [0, 'mul']:
return x * mask
return x - (1 - mask) * 1e12