NLP practice - using your own corpus for Masked Language Model pre-training

1. About MLM

1.1 What is MLM

As the two major pre-training tasks of BERT, MLM and NSP should be familiar to everyone. The NSP task is often dropped in later pre-trained models: RoBERTa abandons NSP outright, and ALBERT replaces it with sentence order prediction. This is because NSP, as a classification task, is too simple and does not help the model learn much, while MLM is retained by most pre-trained models. The experimental results also suggest that BERT's main ability comes from training on the MLM task.

Pre-trained language models, with BERT as the representative, are trained on large-scale corpora to acquire basic language ability. In practical applications, the corpus we face may have some particularities, which makes it necessary to re-train the MLM on our own data.

1.2 How to conduct MLM training

MLM training actually differs across pre-trained models. Today's content takes the most basic BERT as an example. BERT's MLM uses a static mask; in later pre-trained models this strategy is usually replaced by a dynamic mask. There is also the whole word masking variant. Neither is within the scope of today's discussion.

Generally speaking, the task of the so-called masked language model is to mask part of the tokens in a sentence, and then try to restore the masked tokens based on the rest of the sentence.

The mask ratio is generally 15%, and this ratio has been inherited by most subsequent models, but the original BERT paper gave no specific justification for it. In my impression, it was Google's later T5 paper that examined this: the mask ratio was experimented with, and the conclusion was that 15% is the most reasonable (if I remember wrong, please correct me).

After the 15% of tokens are selected, not all of them are replaced with the [MASK] tag. Among the selected part, 80% are replaced with [MASK], 10% are replaced with a random token, and the remaining 10% keep the original token. Doing so improves the robustness of the model. These ratios can also be controlled by yourself.

At this point, some students may ask: since 10% remain unchanged, why not just select 15% × 90% = 13.5% of the tokens? If you read the following code, you will clearly understand this question.

Because the task of MLM is to predict all of the selected 15% of tokens, regardless of whether a token is replaced with [MASK]. That is to say, even a token kept as it is still needs to be predicted.
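To make this concrete, here is a minimal toy sketch (plain Python lists and the random module, not the tensor implementation shown later) of how the 15% selection and the 80/10/10 split interact, and why the tokens kept unchanged are still prediction targets:

import random

tokens = ["my", "dog", "is", "cute", "and", "he", "likes", "playing"]
labels = [None] * len(tokens)      # None = not a prediction target
masked_inputs = list(tokens)

for i, tok in enumerate(tokens):
    if random.random() < 0.15:     # select 15% of tokens as prediction targets
        labels[i] = tok            # every selected token must be predicted
        r = random.random()
        if r < 0.8:                # 80% of the selected: replace with [MASK]
            masked_inputs[i] = "[MASK]"
        elif r < 0.9:              # 10% of the selected: replace with a random token
            masked_inputs[i] = random.choice(tokens)
        # the remaining 10% keep the original token, but stay in labels

print(masked_inputs)
print(labels)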

After introducing the basic content, I will show how to train the masked language model based on the transformers module.

2. Code section

In fact, the transformers module itself provides the MLM training task, and the model is already written; you only need to call its built-in Trainer and the datasets module. Interested students can search for the related tutorials on Hugging Face's official website.
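For comparison, the built-in route looks roughly like the sketch below. It assumes the bert_tokenizer and bert_mlm_model loaded in section 2.1, plus a tokenized_dataset built with the datasets library (a hypothetical name); this is only an outline, not the exact official tutorial code.

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# The collator performs dynamic masking: 15% of tokens, with the 80/10/10 split.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=bert_tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="./mlm_trainer_output",    # hypothetical output directory
    num_train_epochs=10,
    per_device_train_batch_size=4,
)

trainer = Trainer(
    model=bert_mlm_model,
    args=training_args,
    train_dataset=tokenized_dataset,      # assumed: a datasets object containing input_ids
    data_collator=data_collator,
)
trainer.train()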

However, I feel that with datasets I have to write a dataset .py file every time I use it, and if you are not familiar with the arrow data format it is easy to make mistakes. I also think the Trainer is not very easy to use: any small modification is laborious (it assumes it has covered all of the users' needs, but in fact there are some redundant parts).

So I referred to its implementation, disassembled its code, and reorganized it in my own way.

2.1 Preparations

First of all, before writing the core code, make some preparations.
Import all required modules:

import os
import json
import copy
from tqdm.notebook import tqdm

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast

Write a config class to gather all parameters:

class Config:
    def __init__(self):
        pass
    
    def mlm_config(
        self, 
        mlm_probability=0.15, 
        special_tokens_mask=None,
        prob_replace_mask=0.8,
        prob_replace_rand=0.1,
        prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: 被mask的token总数
        :param special_token_mask: 特殊token
        :param prob_replace_mask: 被替换成[MASK]的token比率
        :param prob_replace_rand: 被随机替换成其他token比率
        :param prob_keep_ori: 保留原token的比率
        """
        assert sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) == 1,                 ValueError("Sum of the probs must equal to 1.")
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori
        
    def training_config(
        self,
        batch_size,
        epochs,
        learning_rate,
        weight_decay,
        device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        
    def io_config(
        self,
        from_path,
        save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path

Then set various configurations:

config = Config()
config.mlm_config()
config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
config.io_config(from_path='/data/BERTmodels/huggingface/chinese_wwm/', 
                 save_path='./finetune_embedding_model/mlm/')

Then create the BERT model. Note that the tokenizer here is an ordinary tokenizer, while the model is BertForMaskedLM, i.e. BERT with the MLM task head attached; it is a class already written in transformers.

bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)

2.2 Dataset

Since we have abandoned the datasets package, we now need to implement the data input ourselves. The solution is to use torch's Dataset class, which is generally used together with a DataLoader and a collate function that organizes the batches. I'm being lazy here, so I didn't write a collate function; the batch organization is done inside the dataset instead.

In this class there is a mask_tokens method, which selects all the tokens that need a mask from the data and applies one of the three masking methods to each. This method is taken from transformers; after adapting it from its original class, I put it into our own class for our use. Read this piece of code carefully, and you can answer the question raised in 1.2.

The principle of taking a batch is very simple. At initialization we deepcopy the original data, and then at each step we slice one batch size of texts off the front, so the remaining data shrinks by one batch. The length of this class is defined as the current number of remaining texts divided by the batch size, rounded down. Therefore, when the length of the class becomes 0, all steps of the current epoch have been executed and the data is reset for the next epoch.

class TrainDataset(Dataset):
    """
    Note: since no data_collator is used, batching is done inside the dataset.
    The DataLoader therefore adds an extra batch dimension of size 1; remember to squeeze it before feeding the model.
    """
    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)
        
    def __len__(self):
        return len(self.input_texts) // self.config.batch_size
    
    def __getitem__(self, idx):
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        inputs, labels = self.mask_tokens(features['input_ids'])
        batch = {"inputs": inputs, "labels": labels}
        self.input_texts = self.input_texts[self.config.batch_size: ]
        if not len(self):
            self.input_texts = self.ori_inputs
        
        return batch
        
    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
        indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

Then prepare some corpus for training. The format is very simple: just put all the texts in a list. Pay attention that each text should not exceed 512 tokens, otherwise the extra part will be wasted; appropriate preprocessing can be done.

[
	"这是一条文本",
	"这是另一条文本",
	...,
]
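For example, assuming your corpus is a plain-text file with one piece of text per line (a hypothetical corpus.txt), the list can be built like this:

with open('corpus.txt', 'r', encoding='utf-8') as f:
    training_texts = [line.strip() for line in f if line.strip()]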

Then build the dataloader:

train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
train_dataloader = DataLoader(train_dataset)

2.3 Training

Build the training function; its input parameters are the instantiated model to be trained, the dataloader, and the config:

def train(model, train_dataloader, config):
    """
    Train the masked language model.
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    assert config.device.startswith('cuda') or config.device == 'cpu', ValueError("Invalid device.")
    device = torch.device(config.device)
    
    model.to(device)
    
    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")
        
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    
    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)
    
    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc+1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            input_ids = batch['inputs'].squeeze(0).to(device)
            labels = batch['labels'].squeeze(0).to(device)
            loss = model(input_ids=input_ids, labels=labels).loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            training_loss += loss.item()
        print("Training loss: ", training_loss)

Call it to start training:

train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)

2.4 Save and load

Students who have used the pre-trained model should know that an ordinary BERT has two outputs: the 768-dimensional encoding corresponding to each token, and a sentence feature used to represent the entire sentence.

This sentence feature is obtained by a Pooler module in the model that pools the encoded sentence. However, the Pooler is not trained by the MLM task, but by the NSP task.

Since there is no NSP task, the Pooler cannot be trained, so there is no need to add the Pooler to the model.

Therefore, you need to save the embedding and the encoder separately, and load them separately as well. Note that a model trained this way cannot give the pooled sentence representation from the CLS position; if necessary, you can pool the token outputs manually.

torch.save(bert_mlm_model.bert.embeddings.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
torch.save(bert_mlm_model.bert.encoder.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))

For loading, after instantiating a BERT model, load the two weight files into its embedding component and encoder component respectively.
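A minimal loading sketch under these assumptions (the same base checkpoint and the save paths used above; the pooler, if present, simply keeps its original pre-trained weights):

from transformers import BertModel

# Re-instantiate the base model, then overwrite its embedding and encoder weights
bert_model = BertModel.from_pretrained(config.from_path)
bert_model.embeddings.load_state_dict(
    torch.load(os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
)
bert_model.encoder.load_state_dict(
    torch.load(os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))
)

# If a sentence vector is needed, pool manually, e.g. mean over the token outputs:
# outputs = bert_model(**bert_tokenizer("一条文本", return_tensors='pt'))
# sentence_vec = outputs.last_hidden_state.mean(dim=1)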

That's all for this post. I hope students who have read this blog can gain a deeper understanding of the basic principles of BERT.

Origin blog.csdn.net/weixin_44826203/article/details/121439850