Bert pytorch 版本解读 之 Bert pretraining 中mask的实现

BERT Mask 方法

从Bert 论文中,我们可以知道BERT在pretrain的时候 会对训练集进行MASK 操作, 其中mask的方法是:

  1. 15%的原始数据被mask, 85% 没有被mask.
  2. 对于被mask的15% 分3种处理方式: 1) 其中80%是赋值为MASK. 2) 10%进行random 赋值,3)剩下10%保留原来值.

伯努利函数

在 hunggingface transformer 中, Bert 的mask的方法实现主要是靠torch.bernoulli()函数来完成.
首先, 介绍一下torch.bernoulli() 函数:

  • torch.bernoulli 函数 是从伯努利分布中根据input的概率抽取二元随机数(0或者1),输出与input相同大小的张量, 输出的张量的值只有0和1.
    torch.bernoulli(input, out=None):
        input(Tensor) - 输入为伯努利分布的概率值
        out(Tensor, optional)
    
    • input 输入中所有值必须在[0, 1]区间(即概率值),输出张量的第i个元素值,将以输入张量的第i个概率值等于1.
    • 返回值将会是与输入相同大小的张量,每个值为0或1

Mask 代码注释

在run_lm_finetuning.py中, 有函数 mask_tokens()

def mask_tokens(inputs, tokenizer, args):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
    """
    prob_data = torch.full(labels.shape, args.mlm_probability) 
    	是生成一个labels一样大小的矩阵,里面的值默认是0.15.
    torch.bernoulli(prob_data),从伯努利分布中抽取二元随机数(0或者1),
    	prob_data是上面产生的是一个所有值为0.15(在0和1之间的数),
    	输出张量的第i个元素值,将以输入张量的第i个概率值等于1.
    	(在这里 即输出张量的每个元素有 0.15的概率为1, 0.85的概率为0. 15%原始数据 被mask住)
    """
    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
    """
    mask_indices通过bool()函数转成True,False
    下面对于85%原始数据 没有被mask的位置进行赋值为-1
    """
    labels[~masked_indices] = -1  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    """
    对于mask的数据,其中80%是赋值为MASK.
    这里先对所有数据以0.8概率获取伯努利分布值, 
    然后 和maksed_indices 进行与操作,得到Mask 的80%的概率 indice, 对这些位置赋值为MASK 
    """
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    """
    对于mask_indices剩下的20% 在进行提取,取其中一半进行random 赋值,剩下一般保留原来值. 
    """
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    """
    最后返回 mask之后的input 和 label.
    inputs 为原文+Mask+radom 单词
    labels 为 1 和 -1. 其中1是Mask的位置, -1是没有mask的位置
    """
    return inputs, labels
发布了28 篇原创文章 · 获赞 5 · 访问量 4333

猜你喜欢

转载自blog.csdn.net/m0_37531129/article/details/103059207