[NLP | Natural Language Processing] BERT Prompt Text Classification (including source code)

1. Introduction to Prompt

Prompting is a relatively new approach in NLP: a description of the task is embedded in the input itself, which provides a new way to control the output of a machine learning model.

Prompting uses the knowledge that a pre-trained language model has acquired from a large amount of text data to solve various downstream tasks. Its advantage is that it can reduce or avoid fine-tuning of the pre-trained model, saving computing resources and time while maintaining or improving the model's performance and generalization ability.

The method is to design an appropriate input format for each task and dataset, built from elements such as questions, contexts, prefixes, suffixes, and separators.

2. Use of BERT and Prompt

Prompts can be used to improve BERT's sentence representations. Adding specific words to BERT's input as a prompt can guide BERT toward better sentence vectors:

  • Method 1: Add Prompt at the beginning or end of the sentence;
  • Method 2: Add Prompt in the middle of the sentence.
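
The two placements above can be illustrated with plain string handling. This is a minimal sketch: the helper name `add_prompt` and the choice of the word midpoint as "the middle" are assumptions made for illustration, not something the methods prescribe.

```python
def add_prompt(sentence, prompt, position="start"):
    """Insert a prompt at the start, end, or (word-level) middle of a sentence."""
    if position == "start":
        return f"{prompt} {sentence}"
    if position == "end":
        return f"{sentence} {prompt}"
    if position == "middle":
        words = sentence.split()
        mid = len(words) // 2  # assumed split point: halfway through the words
        return " ".join(words[:mid] + [prompt] + words[mid:])
    raise ValueError(f"unknown position: {position}")


sentence = "the film is long"
prompt = "It was [MASK]."
for position in ("start", "end", "middle"):
    print(position, "->", add_prompt(sentence, prompt, position))
```

Which placement works best depends on the task and data, so it is worth comparing them on a validation set.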

3. Prompt search method

Prompt search looks for the optimal prompt, that is, the one that maximizes BERT's representation ability. There are currently three main search methods:

  • Random search: Randomly generate some prompts, use each one as the input of BERT, calculate the similarity between BERT's output vector and the target vector, and select the prompt with the highest similarity as the optimal prompt.
  • Greedy search: Start with an empty prompt and add one word at the end of the prompt at each step. For each candidate word, feed the extended prompt to BERT, calculate the similarity between BERT's output vector and the target vector, and keep the word with the highest similarity as part of the prompt. Repeat until a preset length or similarity threshold is reached.
  • Reinforcement learning search: Treat prompt generation as a sequential decision-making problem, use a reinforcement learning algorithm to optimize a policy network, and update the network's parameters according to a reward function.
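
Random search and greedy search as described above can be sketched in a few lines. This is only an illustration under stated assumptions: `embed` is a toy hashed bag-of-words stand-in for BERT's sentence vector, and `vocab`, `target`, and all function names are hypothetical. Reinforcement learning search is left out because it needs a policy network and a reward design.

```python
import math
import random


def embed(text, dim=16):
    """Toy stand-in for a BERT sentence vector (hashed bag of words)."""
    vec = [0.0] * dim
    for word in text.lower().split():
        vec[sum(ord(c) for c in word) % dim] += 1.0
    return vec


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def score(prompt, sentence, target):
    """Similarity between the prompted sentence's vector and the target vector."""
    return cosine(embed(f"{prompt} {sentence}".strip()), target)


def random_search(vocab, sentence, target, length=3, trials=50, seed=0):
    """Sample random prompts and keep the one with the highest similarity."""
    rng = random.Random(seed)
    candidates = [" ".join(rng.choice(vocab) for _ in range(length))
                  for _ in range(trials)]
    return max(candidates, key=lambda p: score(p, sentence, target))


def greedy_search(vocab, sentence, target, max_len=3, threshold=0.99):
    """Grow the prompt one word at a time, always taking the best next word."""
    prompt = ""
    best = score(prompt, sentence, target)
    while len(prompt.split()) < max_len and best < threshold:
        word = max(vocab,
                   key=lambda w: score(f"{prompt} {w}".strip(), sentence, target))
        prompt = f"{prompt} {word}".strip()
        best = score(prompt, sentence, target)
    return prompt


vocab = ["it", "was", "great", "movie", "the"]
sentence = "the plot dragged"
target = embed("it was great the plot dragged")  # hypothetical target vector

best_random = random_search(vocab, sentence, target)
best_greedy = greedy_search(vocab, sentence, target)
print("random:", best_random, score(best_random, sentence, target))
print("greedy:", best_greedy, score(best_greedy, sentence, target))
```

With a real model, `embed` would be replaced by a call that returns BERT's sentence vector; the search loops themselves stay the same.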

4. Limitations of the Prompt method

The advantage of BERT + Prompt is that the prompt can guide BERT to generate better sentence vectors, thereby improving the quality and diversity of sentence representations.

For tasks such as sentence similarity, text classification, and text retrieval, BERT + Prompt may be more effective than the original BERT model. For text generation tasks, such as text summarization, paraphrasing, and text continuation, BERT + Prompt is not necessarily more effective than the original BERT model.

Prompt is suitable for multi-task modeling, such as training multiple text tasks together. On a single task, Prompt is therefore not expected to increase model accuracy. I have not seen Prompt used in existing text classification competitions.

5. Case: Prompt text classification

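Input text:

It was [MASK]. <sample input text>

The output at the [MASK] position is passed to a fully connected layer for classification (in the code below, the masked-LM head, which scores every word in the vocabulary).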

5.1 Step 1: Define the model

import torch.nn as nn
from transformers import BertForMaskedLM


class Bert_Model(nn.Module):
    def __init__(self, bert_path, config_file):
        super().__init__()
        # Load the pre-trained weights together with the masked-LM head
        self.bert = BertForMaskedLM.from_pretrained(bert_path, config=config_file)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # The masked-LM head scores every vocabulary entry at every position
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        logit = outputs[0]  # prediction logits, shape [bs, seq_len, vocab_size]
        return logit

5.2 Step 2: Define Dataset

import torch
import torch.utils.data as Data


class MyDataSet(Data.Dataset):
    def __init__(self, sen, mask, typ, label):
        super().__init__()
        # Token ids, attention mask, token type ids, and labels as long tensors
        self.sen = torch.tensor(sen, dtype=torch.long)
        self.mask = torch.tensor(mask, dtype=torch.long)
        self.typ = torch.tensor(typ, dtype=torch.long)
        self.label = torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return self.sen.shape[0]

    def __getitem__(self, idx):
        return self.sen[idx], self.mask[idx], self.typ[idx], self.label[idx]

5.3 Step 3: Add Prompt to the text

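# BERT's mask token is the upper-case string [MASK] (tokenizer.mask_token)
prefix = 'It was [MASK]. '

input_ids, attention_masks, token_type_ids = [], [], []  # inputs for MyDataSet
for i in range(len(x_train)):
    text_ = prefix + x_train[i][0]
    encode_dict = tokenizer.encode_plus(text_, max_length=60, padding="max_length", truncation=True)
    input_ids.append(encode_dict["input_ids"])
    attention_masks.append(encode_dict["attention_mask"])
    token_type_ids.append(encode_dict["token_type_ids"])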
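
The loop above only tokenizes the text; the labels are not shown. Because Step 4 uses `CrossEntropyLoss(ignore_index=-1)` over the vocabulary, the labels are presumably a per-token sequence that is -1 everywhere except at the [MASK] position, where it holds the id of the label word. The sketch below builds such a sequence under that assumption; `build_mlm_labels` and the token ids are illustrative and not taken from the original code.

```python
def build_mlm_labels(input_ids, mask_token_id, label_word_id, ignore_index=-1):
    """Per-token labels: ignore_index everywhere except at the first [MASK]."""
    labels = [ignore_index] * len(input_ids)
    for pos, token_id in enumerate(input_ids):
        if token_id == mask_token_id:
            labels[pos] = label_word_id
            break  # the prompt contains a single [MASK]
    return labels


# Illustrative ids only; in practice use tokenizer.mask_token_id and
# tokenizer.convert_tokens_to_ids(<label word>) from the real tokenizer.
MASK_ID = 103
GOOD_ID = 2204
input_ids = [101, 2009, 2001, MASK_ID, 1012, 102, 0, 0]

labels = build_mlm_labels(input_ids, MASK_ID, GOOD_ID)
print(labels)
```

Each class needs one label word (for example, one word for positive and one for negative), and the resulting label lists are what would be passed as `label` to `MyDataSet`.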

5.4 Step 4: Model Training and Prediction

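from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)  # use the AdamW optimizer
loss_func = nn.CrossEntropyLoss(ignore_index=-1)  # positions labelled -1 are skipped

train_loss_sum = 0.0
model.train()
# train_dataset is assumed to be a DataLoader that yields batches from MyDataSet
for idx, (ids, att_mask, typ, y) in enumerate(train_dataset):
    ids, att_mask, typ, y = ids.to(device), att_mask.to(device), typ.to(device), y.to(device)
    out_train = model(ids, att_mask, typ)  # [bs, seq_len, vocab_size]
    # Flatten so that every token position is one classification over the vocabulary
    loss = loss_func(out_train.view(-1, tokenizer.vocab_size), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss_sum += loss.item()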
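
The code above covers training only. For prediction, one common approach, assumed here rather than taken from the original code, is to read the logits at the [MASK] position and compare the scores of one label word per class. The function name `predict_from_mask`, the token ids, and the label word ids below are toy values chosen for illustration.

```python
import torch


def predict_from_mask(logits, input_ids, mask_token_id, label_word_ids):
    """Map masked-LM logits to class indices.

    logits:         [batch, seq_len, vocab_size] output of the masked-LM head
    input_ids:      [batch, seq_len] token ids that were fed to the model
    mask_token_id:  id of the [MASK] token
    label_word_ids: one vocabulary id per class
    """
    # Position of the first [MASK] in every sequence
    mask_pos = (input_ids == mask_token_id).long().argmax(dim=1)      # [batch]
    batch_idx = torch.arange(input_ids.size(0), device=input_ids.device)
    mask_logits = logits[batch_idx, mask_pos]                         # [batch, vocab]
    # Keep only the scores of the label words, one column per class
    word_ids = torch.tensor(label_word_ids, device=logits.device)
    class_logits = mask_logits[:, word_ids]                           # [batch, classes]
    return class_logits.argmax(dim=1)


# Toy batch: vocabulary of 10 ids, [MASK] has id 4, classes map to words 7 and 8
input_ids = torch.tensor([[1, 4, 3, 5, 2],
                          [1, 6, 4, 7, 2]])
logits = torch.zeros(2, 5, 10)
logits[0, 1, 8] = 5.0  # first sample: word 8 scores highest at the mask
logits[1, 2, 7] = 5.0  # second sample: word 7 scores highest at the mask

preds = predict_from_mask(logits, input_ids, mask_token_id=4, label_word_ids=[7, 8])
print(preds.tolist())
```

In a real run, `logits` would come from `model(ids, att_mask, typ)` under `torch.no_grad()` with the model in `eval()` mode.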

Origin blog.csdn.net/wzk4869/article/details/130548901