预训练语言模型mask函数DataCollatorForLanguageModeling和DataCollatorForWholeWordMask解析

预训练语言模型中一个非常重要的任务是MLM（Masked Language Model，掩码语言模型）任务，该任务需要对原始文本进行mask。
transformers库已经集成了预训练语言模型中的mask机制，这里分析其中的两个类：DataCollatorForLanguageModeling 和 DataCollatorForWholeWordMask。

1.1 DataCollatorForLanguageModeling

这个类实现了BERT论文中为MLM任务提出的mask机制：按一定概率（默认15%）选取token，被选中的token中80%替换为[MASK]，10%替换为随机token，10%保持不变。下面对照transformers库中的原始代码进行讲解。

from dataclasses import dataclass


@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.

    The class must be a ``@dataclass``: that is what makes ``DataCollatorForLanguageModeling(tokenizer)`` accept the
    fields below as constructor arguments and what makes ``__post_init__`` run after construction.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        # MLM substitutes the [MASK] token into the input, so the tokenizer must define one.
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate ``examples`` into a padded batch and add ``labels``.

        Returns:
            The batch with ``"input_ids"`` and ``"labels"`` (plus whatever else ``tokenizer.pad`` returns for dict
            inputs). With ``mlm=True`` the labels come from :meth:`mask_tokens`; with ``mlm=False`` they are a copy of
            ``input_ids`` with padding positions set to -100.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], (dict, BatchEncoding)):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {
                "input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        if self.mlm:
            batch["input_ids"], batch["labels"] = self.mask_tokens(
                batch["input_ids"], special_tokens_mask=special_tokens_mask
            )
        else:
            # Non-MLM: predict every token, but ignore padding in the loss.
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
        return batch

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: tensor of token ids (it must be a tensor, not a list). It is modified in place.
            special_tokens_mask: optional tensor with the same shape as ``inputs`` that is truthy at special-token
                positions. When omitted it is computed with ``tokenizer.get_special_tokens_mask``.

        Returns:
            ``(inputs, labels)``. ``labels`` holds the original token id at every position selected for masking and
            -100 everywhere else, so the loss is only computed on the selected positions.
        """
        # Keep a copy of the original ids; ``inputs`` is overwritten in place below.
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`).
        # Every position starts with selection probability ``mlm_probability`` (0.15 by default).
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        # special_tokens_mask has the same shape as inputs and is True at special-token positions
        # (e.g. [CLS]/[SEP], and any other id the tokenizer reports as special).
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()
        # Zero the probability wherever special_tokens_mask is True so special tokens are never selected.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        # Each remaining token is selected independently: a Bernoulli draw with p = mlm_probability
        # (1 = selected, 0 = kept).
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of selected tokens that were not replaced by [MASK]).
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

1.2 用法示例:

import torch
from transformers import BertForMaskedLM, BertTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask

path = '../PTM/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(path)
model = BertForMaskedLM.from_pretrained(path)  # not used by the collator below

sent = "我爱北京天安门,天安门上太阳升"

# Create the collator; the tokenizer is its only required argument.
data_collator = DataCollatorForLanguageModeling(tokenizer)

encoded_dict = tokenizer.encode_plus(
    sent,  # input text
    add_special_tokens=True,  # add '[CLS]' and '[SEP]'
    max_length=32,  # pad & truncate to this length
    truncation=True,
    padding="max_length",  # replaces the deprecated pad_to_max_length=True
    return_attention_mask=True,  # also return the attention mask
)

print(encoded_dict)
input_ids = [torch.tensor(encoded_dict['input_ids'])]
# Pass a list of input_ids tensors; the collator returns the (masked) input_ids and the labels.
# In labels, -100 marks positions that were NOT selected for masking (no loss is computed there).
output = data_collator(input_ids)

print(output)

'''
{'input_ids': tensor([[ 101, 2769, 4263, 1266,  776, 1921,  103, 7305, 8024, 1921, 2128, 7305,
          677, 1922, 7345, 1285,  102,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]]), 'labels': tensor([[-100, -100, -100, -100, -100, -100, 2128(被mask掉的词)
            , -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]])}

'''

2.1 DataCollatorForWholeWordMask

这个类实现了whole word mask（全词掩码）机制，并且继承自DataCollatorForLanguageModeling。对于中文来说，需要和prepare_ref函数一起使用。prepare_ref函数也是transformers官方实现的一个针对中文的函数：它先用LTP对句子分词，再返回每句话中“属于某个词但不是该词首字”的汉字（即需要加上“##”前缀的子词）在BERT token序列中的位置。

def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer, batch_size: int = 100):
    """
    Return, for each line, the BERT-token positions of Chinese characters that continue a whole word.

    Each line is segmented into words with LTP and tokenized with the BERT tokenizer (special tokens included,
    so position 0 is ``[CLS]``). ``add_sub_symbol`` then marks tokens with a ``##`` prefix; presumably it prefixes
    the non-initial characters of each LTP word (TODO confirm against its definition). The positions of the
    ``##`` tokens that are a single Chinese character are returned.

    Example::

        sent = "我爱北京天安门,天安门上太阳升"
        bert_tokens = ['[CLS]', '我', '爱', '北', '##京', '天', '##安', '##门', ',',
                       '天', '##安', '##门', '上', '太', '##阳', '##升', '[SEP]']
        return: [4, 6, 7, 10, 11, 14, 15]
        #        京, 安, 门, 安,  门,  阳,  升

    Args:
        lines: raw input sentences.
        ltp_tokenizer: LTP instance used for word segmentation.
        bert_tokenizer: BERT tokenizer used to produce the token ids.
        batch_size: number of lines sent to LTP / the BERT tokenizer per call (default 100).

    Returns:
        One list of token positions per input line, in input order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Word segmentation with LTP, in batches.
    ltp_res = []
    for start in range(0, len(lines), batch_size):
        segmented = ltp_tokenizer.seg(lines[start : start + batch_size])[0]
        ltp_res.extend(get_chinese_word(words) for words in segmented)
    assert len(ltp_res) == len(lines)

    # BERT token ids (with [CLS]/[SEP]), in batches.
    bert_res = []
    for start in range(0, len(lines), batch_size):
        encoded = bert_tokenizer(
            lines[start : start + batch_size], add_special_tokens=True, truncation=True, max_length=512
        )
        bert_res.extend(encoded["input_ids"])
    assert len(bert_res) == len(lines)

    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):
        input_tokens = [bert_tokenizer._convert_id_to_token(token_id) for token_id in input_ids]
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # We only save pos of chinese subwords start with ##, which mean is part of a whole word.
        for i, token in enumerate(input_tokens):
            if token.startswith("##"):
                clean_token = token[2:]
                # save chinese tokens' pos
                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
                    ref_id.append(i)
        ref_ids.append(ref_id)

    assert len(ref_ids) == len(bert_res)

    return ref_ids
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for language modeling with whole word masking.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling, masking all sub-tokens of a word together

    Examples may be dicts with an optional ``"chinese_ref"`` entry (e.g. the output of ``prepare_ref``): the listed
    positions are treated as ``##`` continuation tokens so that a Chinese word is masked as a unit.
    """

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` and return ``{"input_ids": ..., "labels": ...}`` with whole-word masking applied."""
        if isinstance(examples[0], (dict, BatchEncoding)):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

        mask_labels = []
        for e in examples:
            ref_tokens = [self.tokenizer._convert_id_to_token(token_id) for token_id in tolist(e["input_ids"])]

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢].
            if "chinese_ref" in e:
                # A set gives O(1) lookups; out-of-range positions are ignored.
                ref_pos = set(tolist(e["chinese_ref"]))
                len_seq = len(e["input_ids"])
                for i in ref_pos:
                    if 0 <= i < len_seq:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        # Pad the 0/1 masks exactly like the inputs so both tensors have the same shape.
        batch_mask = _collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Tokens starting with ``##`` are grouped with the preceding token into one candidate word. ``[CLS]``/``[SEP]``
        are never candidates. Padding tokens are neither candidates nor counted toward the budget: inputs that were
        padded beforehand would otherwise spend the masking budget on padding. Whole words are picked in random order
        until about ``mlm_probability`` of the non-padding tokens are covered (at most ``max_predictions``).

        Returns:
            A list of 0/1 ints with the same length as ``input_tokens``; 1 marks a selected position.
        """
        pad_token = self.tokenizer.pad_token
        cand_indexes = []
        num_real_tokens = 0
        for i, token in enumerate(input_tokens):
            if pad_token is not None and token == pad_token:
                continue
            num_real_tokens += 1
            if token == "[CLS]" or token == "[SEP]":
                continue

            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        # The 15% budget counts tokens (characters), not words: masking is applied word by word,
        # but whether 15% has been reached is measured in tokens.
        num_to_predict = min(max_predictions, max(1, int(round(num_real_tokens * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        assert len(covered_indexes) == len(masked_lms)
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

    def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        With whole word masking (wwm) the selected positions come directly from ``mask_labels`` rather than from a
        Bernoulli draw.

        Args:
            inputs: padded batch of token ids; modified in place.
            mask_labels: 0/1 tensor with the same shape as ``inputs`` (from ``_whole_word_mask``); it is zeroed in
                place at special-token and padding positions.

        Returns:
            ``(inputs, labels)`` where ``labels`` holds the original id at selected positions and -100 elsewhere.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

        # mask_labels already encodes the ~15% of tokens chosen word by word, so it plays the role that the
        # Bernoulli-sampled masked_indices plays in DataCollatorForLanguageModeling.
        probability_matrix = mask_labels

        # Never select special tokens or padding.
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

2.2 用法示例

from ltp import LTP
from prepare import prepare_ref
from transformers import BertForMaskedLM, BertTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask

path = '../PTM/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(path)
model = BertForMaskedLM.from_pretrained(path)  # not used by the collator below

# Masking is random; uncomment to make runs reproducible
# (also requires `import random`, `import numpy as np` and `import torch`).
# seed_val = 42
# random.seed(seed_val)
# np.random.seed(seed_val)
# torch.manual_seed(seed_val)
# torch.cuda.manual_seed_all(seed_val)


'''
这个方法也是bert自带的实现了随机mask 15%单词任务的方法。
由于我是要在领域内的数据集上微调bert,所以还是用language model的训练方法。
'''


ltp = LTP()

sent = "我爱北京天安门,天安门上太阳升"
# Positions of the characters that continue a word (the "##" sub-words) for this sentence.
ref = prepare_ref([sent], ltp, tokenizer)[0]
print(ref)

# No return_tensors here: the whole-word collator below is given plain lists.
encoded_dict = tokenizer.encode_plus(
    sent,  # input text
    add_special_tokens=True,  # add '[CLS]' and '[SEP]'
    max_length=32,  # pad & truncate to this length
    truncation=True,
    padding="max_length",  # replaces the deprecated pad_to_max_length=True
    return_attention_mask=True,  # also return the attention mask
)

data_collator = DataCollatorForWholeWordMask(tokenizer)

# Each feature carries the sub-word positions ("chinese_ref"); input_ids is a list, not a tensor.
features = [{'input_ids': encoded_dict['input_ids'], 'chinese_ref': ref}]
tmp = data_collator(features)
print(tmp)

'''
{'input_ids': tensor([[ 101, 2769, 4263, 1266,  776, 1921, 2128, 7305, 8024,  103,  103,  103,
          677, 1922, 7345, 1285,  102,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]]), 'labels': tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, 1921, 2128, 7305(第二个“天安门”三个字被随机替换了),
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]])}
'''

猜你喜欢

转载自blog.csdn.net/mch2869253130/article/details/123140141