NLP practice: use your own corpus for Masked Language Model (MLM) pre-training
1. About MLM
1.1 What is MLM
As the two major pre-training tasks of Bert, MLM and NSP should be familiar to everyone. Of the two, NSP is often dropped in later pre-trained models. For example, Roberta abandons the NSP task outright, and Albert replaces NSP with sentence order prediction. This is because NSP is too simple as a classification task and does not help the model learn much, while MLM is retained by most pre-trained models. Roberta's experimental results also suggest that the main ability of Bert comes from training on the MLM task.
Pre-trained language models, with Bert as the representative, are trained on large-scale corpora to obtain their basic capabilities. In practical applications, the corpus we face may have some particularities, which makes it worthwhile to run MLM training again on that corpus.
1.2 How to conduct MLM training
MLM training actually differs between pre-trained models. The content introduced today takes the most basic one, Bert, as an example. Bert's MLM uses a static mask, and in later pre-trained models this strategy is usually replaced by a dynamic mask. There are also whole word masking models. Neither of these is within the scope of today's discussion.
The masked language model task, put simply, replaces some of the tokens in a sentence and then tries to restore the replaced tokens from the rest of the sentence.
The mask ratio is generally 15%, and this ratio has been inherited by most later models. The original BERT paper, however, gives no specific justification for this value. As far as I recall, the T5 paper, also from Google, experimented with the mask ratio and concluded that 15% is the most reasonable choice (if I remember wrong, please correct me).
After the 15% of tokens have been selected, not all of them are replaced with the [mask] tag. Of the selected 15%, 80% are replaced with [mask], 10% are replaced with a random token, and the remaining 10% keep the original token. Doing so improves the robustness of the model. These ratios can also be adjusted yourself.
At this point, some readers may ask: since 10% remain unchanged, why not just select 15% * 90% = 13.5% of the tokens? If you read the code below, the answer will be clear.
The reason is that the MLM task predicts all of the selected 15% of tokens, whether or not a token was replaced with [mask]. Even a token that is kept as it is still has to be predicted.
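The selection rule described above can be sketched in a few lines of plain Python. This is only an illustration on a toy vocabulary, not the code used later in this post: the function name, the toy `vocab`, and the use of `None` in place of the -100 ignore label are all assumptions made for the example.

```python
import random

MASK = "[MASK]"
IGNORE = None  # stands in for the -100 "ignore" label used by the loss


def mask_tokens(tokens, vocab, mlm_probability=0.15, rng=random):
    """Return (inputs, labels) for one tokenized sentence."""
    inputs, labels = list(tokens), [IGNORE] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mlm_probability:
            continue  # not selected: no label, so no loss at this position
        labels[i] = tok  # every selected position must be predicted
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: the input keeps the original token, the label stays
    return inputs, labels


rng = random.Random(0)
vocab = ["the", "cat", "sat", "on", "a", "mat", "dog", "ran"]
tokens = [rng.choice(vocab) for _ in range(20000)]
inputs, labels = mask_tokens(tokens, vocab, rng=rng)

selected = [i for i, lab in enumerate(labels) if lab is not IGNORE]
# Includes the "keep" case plus random replacements that happen to
# draw the original token again (likely with such a tiny vocabulary).
kept = [i for i in selected if inputs[i] == tokens[i]]
print(f"selected: {len(selected) / len(tokens):.3f}")
print(f"selected but unchanged in the input: {len(kept) / len(selected):.3f}")
```

The positions in `kept` look untouched in the input, yet they still carry a label, which is why selecting only 13.5% of the tokens would not be equivalent.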
With the basics covered, I will now show how to train a masked language model based on the transformers module.
2. Code section
In fact, the transformers module itself provides MLM training, and the model has already been written. You only need to call its built-in trainer together with the datasets module. Interested readers can go to huggingface's official website to search for the related tutorials.
However, I find that with datasets I have to write a py file for the data set every time I use it. If you are not familiar with the arrow data format, it is easy to make mistakes. I also think the trainer is not very easy to use: any small modification is laborious (it is carefully written and tries to cover every user's needs, but in practice some parts are redundant).
So I referred to its implementation, took its code apart, and reorganized it in my own way.
2.1 Preparations
Before writing the core code, some preparation is needed.
Import all required modules:
import os
import json
import copy
from tqdm.notebook import tqdm
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast
Write a config class to gather all parameters:
class Config:
    def __init__(self):
        pass

    def mlm_config(
        self,
        mlm_probability=0.15,
        special_tokens_mask=None,
        prob_replace_mask=0.8,
        prob_replace_rand=0.1,
        prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: fraction of tokens selected for prediction
        :param special_tokens_mask: mask marking special tokens (never selected)
        :param prob_replace_mask: share of selected tokens replaced with [MASK]
        :param prob_replace_rand: share of selected tokens replaced with a random token
        :param prob_keep_ori: share of selected tokens that keep the original token
        """
        # Compare with a tolerance: a plain == 1 can fail on floating point sums
        if abs(prob_replace_mask + prob_replace_rand + prob_keep_ori - 1) > 1e-8:
            raise ValueError("Sum of the probs must equal to 1.")
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori

    def training_config(
        self,
        batch_size,
        epochs,
        learning_rate,
        weight_decay,
        device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device

    def io_config(
        self,
        from_path,
        save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path
Then set various configurations:
config = Config()
config.mlm_config()
config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
config.io_config(from_path='/data/BERTmodels/huggingface/chinese_wwm/',
save_path='./finetune_embedding_model/mlm/')
Then create the BERT model. Note that the tokenizer here is an ordinary tokenizer, while the model is BertForMaskedLM, that is, BERT with the MLM task head on top. It is a class already written in transformers.
bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)
2.2 Dataset
Since the datasets package has been dropped, we now need to implement the data input ourselves. The solution is to use torch's Dataset class. This class is generally used together with a collate function, which organizes the batches when the DataLoader is built. I was lazy here and did not write the collate function, so the batches are organized inside the dataset instead.
This class has a mask_tokens method, which selects all the tokens to be masked from the data and applies one of the three mask treatments to each. The method is taken from transformers. After adapting and testing it, I put it in our own class for our own use. Read this piece of code carefully and you can answer the question raised in 1.2.
The principle of taking a batch is very simple. At the beginning, we deepcopy the original data, and then slice one batch off the front each time, so the current data shrinks by one batch per step. We define the length of this class as the current data length divided by the batch size, rounded down. When the length of the class reaches 0, all the steps of this epoch have been executed, so the data is restored from the copy and the next epoch can begin.
class TrainDataset(Dataset):
    """
    Note: no data_collator is used, so batching is done inside the dataset.
    The DataLoader output therefore carries an extra leading batch dimension;
    remember to squeeze it before passing the tensors to the model.
    """
    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)

    def __len__(self):
        # Number of full batches left in the current epoch
        return len(self.input_texts) // self.config.batch_size

    def __getitem__(self, idx):
        # idx is ignored: every call slices the next batch off the front
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        inputs, labels = self.mask_tokens(features['input_ids'])
        # attention_mask is returned as well so the model can ignore padding
        batch = {"inputs": inputs, "labels": labels, "attention_mask": features['attention_mask']}
        self.input_texts = self.input_texts[self.config.batch_size:]
        if not len(self):
            # Epoch finished: restore the full corpus for the next epoch
            self.input_texts = self.ori_inputs
        return batch

    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.config.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()

        # Special tokens are never selected
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word.
        # The probability is conditional on the token not having become [MASK].
        remaining = 1 - self.config.prob_replace_mask
        current_prob = self.config.prob_replace_rand / remaining if remaining > 0 else 0.0
        indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
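One detail in mask_tokens deserves a closer look: the random replacement is drawn with `current_prob`, not with `prob_replace_rand` directly. The second draw only applies to the selected tokens that did not become [MASK], so its probability has to be conditional on that remaining share. The small simulation below is a sketch of that arithmetic in plain Python (the variable `shares` and the sample size are just for the illustration):

```python
import random

prob_replace_mask, prob_replace_rand, prob_keep_ori = 0.8, 0.1, 0.1

# Stage 2 only sees the selected tokens that stage 1 did NOT turn into [MASK],
# so its probability is conditional on that remaining share: 0.1 / 0.2.
current_prob = prob_replace_rand / (1 - prob_replace_mask)

rng = random.Random(0)
counts = {"mask": 0, "random": 0, "keep": 0}
n = 100_000
for _ in range(n):
    if rng.random() < prob_replace_mask:      # stage 1: [MASK]
        counts["mask"] += 1
    elif rng.random() < current_prob:         # stage 2: random token
        counts["random"] += 1
    else:                                     # everything else: keep original
        counts["keep"] += 1

shares = {k: v / n for k, v in counts.items()}
print(current_prob)
print(shares)
```

The overall shares come out close to 80% / 10% / 10%, which is what the three config values ask for.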
Then prepare some corpus for training. The format is very simple: put all the texts in a list. Keep each text within 512 tokens, because anything beyond that is truncated and wasted. Appropriate preprocessing can be applied first.
training_texts = [
    "这是一条文本",    # "This is a piece of text"
    "这是另一条文本",  # "This is another piece of text"
    # ... more texts
]
Then build the dataloader:
train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
train_dataloader = DataLoader(train_dataset)
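The batch-slicing behaviour of TrainDataset can be checked in isolation, without a tokenizer or a model. The sketch below reproduces only the slicing and refill logic on a plain list of integers; the class name `SlicingBatches` and the method `next_batch` are hypothetical names used for this illustration.

```python
import copy


class SlicingBatches:
    """Hand out fixed-size batches by slicing them off the front of a list."""

    def __init__(self, items, batch_size):
        self.items = items
        self.batch_size = batch_size
        self.original = copy.deepcopy(items)  # kept to refill after an epoch

    def __len__(self):
        # Number of full batches still left in the current epoch
        return len(self.items) // self.batch_size

    def next_batch(self):
        batch = self.items[: self.batch_size]
        self.items = self.items[self.batch_size:]
        if len(self) == 0:  # epoch finished: refill for the next one
            self.items = self.original
        return batch


batches = SlicingBatches(list(range(10)), batch_size=4)
steps_per_epoch = len(batches)
epoch_1 = [batches.next_batch() for _ in range(steps_per_epoch)]
epoch_2 = [batches.next_batch() for _ in range(steps_per_epoch)]
print(epoch_1)
print(epoch_2)
```

Two things follow from this design: a trailing partial batch (here the items 8 and 9) is never used, and the order is the same in every epoch, so shuffle the list yourself beforehand if you want a different order.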
2.3 Training
Build a training function. Its input parameters are the instantiated model to be trained, the dataloader, and the config:
def train(model, train_dataloader, config):
    """
    Training
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    if not (config.device.startswith('cuda') or config.device == 'cpu'):
        raise ValueError("Invalid device.")
    device = torch.device(config.device)
    model.to(device)

    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")

    # Weight decay is applied to everything except biases and LayerNorm weights.
    # Per-group values override the optimizer default, so config.weight_decay is set here.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": config.weight_decay},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate)

    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc + 1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            # The DataLoader adds a leading dimension of size 1; squeeze it away
            input_ids = batch['inputs'].squeeze(0).to(device)
            labels = batch['labels'].squeeze(0).to(device)
            # Use the attention mask if the dataset provides one, so padding is ignored
            attention_mask = batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.squeeze(0).to(device)
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        # Sum of the per-step losses over the epoch
        print("Training loss: ", training_loss)
Call it to run the training:
train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)
2.4 Save and load
Readers who have used pre-trained models will know that an ordinary bert has two outputs: the 768-dimensional encoding of each token, and a sentence feature that represents the entire sentence.
This sentence feature is produced by a Pooler module in the model, which pools the encoded sentence. The Pooler, however, is trained by the NSP task, not by the MLM task.
Since there is no NSP task here, the Pooler cannot be trained, so there is no need to include it in the model.
Therefore, save the embedding and the encoder separately, and load them separately as well. As a consequence, a model trained this way cannot provide the pooled sentence representation from the CLS position. If you need one, you can pool manually.
os.makedirs(config.save_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(bert_mlm_model.bert.embeddings.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
torch.save(bert_mlm_model.bert.encoder.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))
For loading, after instantiating the bert model, use bert's embedding component and encoder component to read the two weight files respectively.
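A minimal sketch of this save and load round trip is given below. To keep it runnable without downloading anything, it assumes a tiny, randomly initialised BertConfig and a temporary directory; the file names and sizes are made up for the example. In practice you would instantiate the model from your own checkpoint and point the paths at the two files saved above.

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertForMaskedLM, BertModel

# Tiny random config so the sketch runs offline (an assumption for this example)
tiny_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
)
mlm_model = BertForMaskedLM(tiny_config)  # stands in for the trained MLM model

with tempfile.TemporaryDirectory() as save_path:
    eb_path = os.path.join(save_path, "bert_mlm_eb.bin")
    ec_path = os.path.join(save_path, "bert_mlm_ec.bin")
    torch.save(mlm_model.bert.embeddings.state_dict(), eb_path)
    torch.save(mlm_model.bert.encoder.state_dict(), ec_path)

    # Fresh bert without a Pooler; load the two components separately
    bert = BertModel(tiny_config, add_pooling_layer=False)
    bert.embeddings.load_state_dict(torch.load(eb_path, map_location="cpu"))
    bert.encoder.load_state_dict(torch.load(ec_path, map_location="cpu"))

# Check that the reloaded model reproduces the original token encodings
bert.eval()
mlm_model.eval()
input_ids = torch.tensor([[1, 5, 7, 2]])
with torch.no_grad():
    reloaded = bert(input_ids=input_ids).last_hidden_state
    original = mlm_model.bert(input_ids=input_ids).last_hidden_state
print(torch.allclose(reloaded, original))
```

Because the model is built with `add_pooling_layer=False`, only the per-token encodings in `last_hidden_state` are available, which matches the note above about pooling manually.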
That is all for this post. I hope that readers of this blog come away with a deeper understanding of the basic principles of Bert.