Understanding the Transformer mask mechanism from the source code

The principle of the mask mechanism is that, on the decoder side, each prediction is made from the encoder output plus the words that have already been predicted, while the encoder's self-attention has no such restriction. The mask is essentially applied inside the attention computation, so let's start with the attention implementation:

import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # this corresponds to Q multiplied by the transpose of K in the formula
    """
    query tensor of shape [B, H, L_q, D_q]
    key   tensor of shape [B, H, L_k, D_k]
    value tensor of shape [B, H, L_v, D_v]; in general it has the same length as key
    """
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

We know that training is done in batches of size batch_size, so shorter sentences are padded, and the pad id is usually taken to be 0. During attention the scores computed at those padded positions are ordinary finite numbers (0 if the pad embedding is all zeros), so softmax would still assign the pad words a non-negligible share of the attention. We therefore set the scores of padded positions to a very large negative number, effectively negative infinity, so that after softmax their weight is essentially zero; that is what scores = scores.masked_fill(mask == 0, -1e9) means. It is then easy to imagine that, in the decoder, the words that have not yet been predicted can be hidden in exactly the same way as padding: the masking mechanism is identical, only the mask matrix is different.
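Here is a quick, self-contained illustration of that masked_fill line, using a toy row of scores and a toy padding mask (not taken from the actual model):

import torch
import torch.nn.functional as F

# one query attending over 4 key positions; suppose the last two keys are padding
scores = torch.tensor([[2.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 0, 0]])

print(F.softmax(scores, dim=-1))
# tensor([[0.6103, 0.2245, 0.0826, 0.0826]])  -> the pad positions still get ~17% of the attention

masked_scores = scores.masked_fill(mask == 0, -1e9)
print(F.softmax(masked_scores, dim=-1))
# tensor([[0.7311, 0.2689, 0.0000, 0.0000]])  -> the pad positions contribute nothing

With that in mind, let's find this part in the decoder code: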

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        # mask mechanisms for both the source language and the target language
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # self-attention over the target language, so it needs tgt_mask;
        # this mask matrix is built from the words that have already been predicted
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        # attention over the encoder output; since the encoder stage has padding,
        # this mask matrix is the same one used in the encoder stage
        return self.sublayer[2](x, self.feed_forward)
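The two masks that reach this layer have different shapes, and it helps to make them explicit. Below is a minimal sketch of those shapes and of how they broadcast against the attention scores; the mask shapes follow the Batch class shown later in this post, while the extra unsqueeze for the head dimension is what the multi-head wrapper (not shown here) typically does before calling attention:

import torch

batch, heads, src_len, tgt_len = 2, 8, 7, 5

# src_mask, (src != pad).unsqueeze(-2), has shape (batch, 1, src_len); here simplified to all ones
src_mask = torch.ones(batch, 1, src_len, dtype=torch.bool)
# tgt_mask from make_std_mask has shape (batch, tgt_len, tgt_len); here simplified to a plain lower-triangular matrix
tgt_mask = torch.ones(batch, tgt_len, tgt_len).tril().bool()

# inside multi-head attention the scores carry a head dimension
self_attn_scores = torch.randn(batch, heads, tgt_len, tgt_len)
src_attn_scores = torch.randn(batch, heads, tgt_len, src_len)

print(tgt_mask.unsqueeze(1).shape)  # torch.Size([2, 1, 5, 5]) -> broadcasts over the heads
print(src_mask.unsqueeze(1).shape)  # torch.Size([2, 1, 1, 7]) -> broadcasts over heads and query positions

masked_self = self_attn_scores.masked_fill(tgt_mask.unsqueeze(1) == 0, -1e9)
masked_src = src_attn_scores.masked_fill(src_mask.unsqueeze(1) == 0, -1e9)
print(masked_self.shape, masked_src.shape)  # torch.Size([2, 8, 5, 5]) torch.Size([2, 8, 5, 7])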

Next, let's trace back where these masks come from. The top-level module we build is EncoderDecoder:

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        # word embedding of the source language, combined with the position embedding
        self.tgt_embed = tgt_embed
        # word embedding of the target language, combined with the position embedding
        self.generator = generator
        # this is where the final output is produced

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

During training we call model.forward, in this part:

def run_epoch(args, data_iter, model, loss_compute, valid_params=None, epoch_num=0,
              is_valid=False, is_test=False, logger=None):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    if valid_params is not None:
        src_dict, tgt_dict, valid_iter = valid_params
        hist_valid_scores = []

    bleu_all = 0
    count_all = 0

    for i, batch in enumerate(data_iter):
        model.train()

        out = model.forward(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        # all of the arguments come from the batch
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        # this step both computes the loss and updates the parameters
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens

That is the training step. Where does the data, and in particular the mask matrices, come from? They come from the batch, so the key question is how the batch is built. Stepping further back into train.py, we find:

_, logger_file = train_utils.run_epoch(args, (train_utils.rebatch(pad_idx, b) for b in train_iter),
                                       model_parallel if args.multi_gpu else model, train_loss_fn,
                                       valid_params=valid_params,
                                       epoch_num=epoch, logger=logger_file)

The batch comes from the rebatch function, applied to an iterator over the training data; train_iter is built with torchtext, which we will not go into here. So the key is the rebatch function below:

def rebatch(pad_idx, batch):
    "Fix order in torchtext"
    src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    # the data is read as sequence * batch_size, which is determined by the Field in torchtext,
    # so it needs to be converted to batch_size * sequence
    return Batch(src, trg, pad_idx)
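As a quick sanity check of what that transpose does (toy tensor, shapes only):

import torch

# with torchtext's default Field settings the iterator yields (seq_len, batch_size) tensors
src = torch.arange(8).reshape(4, 2)   # pretend seq_len = 4, batch_size = 2
print(src.shape)                      # torch.Size([4, 2])
print(src.transpose(0, 1).shape)      # torch.Size([2, 4]) -> (batch_size, seq_len), what Batch expects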

We finally reach the Batch class, where the most critical information comes from:

class Batch:
    "Object for holding a batch of data with mask during training."

    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        # at prediction time there is no tgt, so it is None
        if trg is not None:
            self.trg = trg[:, :-1]
            # drop the last token of each target sentence
            self.trg_y = trg[:, 1:]
            # drop the first token
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).sum().item()
            # number of real (non-pad) tokens in the target

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & transformer.subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
        # tgt.size(-1) is the sequence length
        return tgt_mask

In the Batch class, trg being None is easy to understand: at prediction time there is no target sentence, and the Batch holds only the input (how attention masking works during prediction is a question we will set aside for now). Look at src_mask first: it masks the source language, that is, it is the mask used by the encoder's self-attention. This is easy to understand: every non-pad position becomes 1 and every pad position becomes 0, giving a 0/1 matrix. self.trg = trg[:, :-1] drops the last token of the target, which is not a real word but the '<eos>' marker, so the decoder input still begins with '<sos>'; self.trg_y = trg[:, 1:] drops that leading '<sos>' and becomes the expected output. The next and most critical step is how the target-language mask matrix is obtained:

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    # ones strictly above the diagonal (k=1); the '== 0' below flips this into a
    # lower-triangular boolean mask: position i may only attend to positions <= i
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

What does this function do?

Let's first write a standalone version and run it:

import numpy as np

def subsequentmask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return subsequent_mask == 0

print(subsequentmask(5))

>>

[[[ True False False False False]
  [ True  True False False False]
  [ True  True  True False False]
  [ True  True  True  True False]
  [ True  True  True  True  True]]]

When this numpy array is converted into a torch tensor, it has dimension (1, 5, 5): the leading dimension of 1 lets it broadcast over every sentence in the batch, and the 5 x 5 matrix inside is the mask for one sentence. Row i of that matrix is the mask used when predicting position i: only positions up to and including i, that is, the words already generated, are left visible, and everything after them is masked out.
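To tie everything together, here is a minimal standalone sketch, mirroring make_std_mask above, that builds the combined padding-plus-subsequent mask for a made-up target batch (pad id 0, one sentence of length 5 whose last token is padding):

import numpy as np
import torch

def subsequent_mask(size):
    attn_shape = (1, size, size)
    return torch.from_numpy(np.triu(np.ones(attn_shape), k=1).astype('uint8')) == 0

pad = 0
tgt = torch.tensor([[2, 7, 9, 3, 0]])            # the final 0 is padding

pad_mask = (tgt != pad).unsqueeze(-2)            # shape (1, 1, 5)
tgt_mask = pad_mask & subsequent_mask(tgt.size(-1)).type_as(pad_mask)
print(tgt_mask)
# tensor([[[ True, False, False, False, False],
#          [ True,  True, False, False, False],
#          [ True,  True,  True, False, False],
#          [ True,  True,  True,  True, False],
#          [ True,  True,  True,  True, False]]])
# row i is the mask for predicting position i: only already-generated words are
# visible, and the padded position is never visible to any query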

Origin www.cnblogs.com/wevolf/p/12484972.html