How to Fully and Efficiently Train Multi-Round Dialogue Large Models


The following article comes from YeungNLP, by He Fengweibai.

YeungNLP is the official public account of Firefly, an open-source Chinese large language model project.


01

Foreword

Recently, many readers have been asking about and discussing how to train the multi-round dialogue ability of large models. This article describes in detail how the Firefly project makes full use of multi-round dialogue data to train large models. Note the two key words, "sufficient" and "efficient", which are the defining characteristics of how Firefly trains on multi-round dialogue. This method may differ from the multi-round dialogue training approach most readers have in mind.

02

Model Effect

Before introducing the multi-round dialogue training method, let's first look at the multi-round dialogue capability of the firefly-ziya-13b model trained by Firefly. The following replies were generated by the model and have not been manually edited.

Multi-round dialogue example 1: (screenshots omitted)

Multi-round dialogue example 2: (screenshots omitted)

03

Existing Methods

Suppose we have a piece of multi-round dialogue data with the following content. For ease of explanation, in the n-th round of dialogue we denote the user input and the assistant reply as Usern and Assistantn, respectively.

User1: Hello
Assistant1: Hello, how can I help you?
User2: How is the weather today?
Assistant2: It is sunny in Beijing today, 25°C, with strong UV; please take precautions.
User3: Thank you
Assistant3: You're welcome

Some background knowledge first, to help with the explanation that follows: in the instruction fine-tuning stage, generally only the loss on the Assistant reply part is back-propagated to update the weights, while the loss on the User part is not used to update the weights.

How can we use the multi-round dialogue data above to train a large model? After discussion and research, we found that there are currently two main methods, but neither of them is both sufficient and efficient.

Method One

The texts of User1, Assistant1, User2, Assistant2, and User3 are all treated as the model's input, and the text of Assistant3 is treated as the prediction target, so only the loss on Assistant3 participates in the weight update. This layout is sketched below.

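To make this concrete, here is a rough sketch of Method One applied to the example above (our own illustration, with delimiter tokens omitted for simplicity):

input  (loss masked): User1, Assistant1, User2, Assistant2, User3
target (loss used)  : Assistant3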

The drawback of this method is that the multi-round dialogue data is not fully utilized: the contents of Assistant1 and Assistant2 do not participate in training, so that part of the data is wasted. Moreover, in a lot of multi-round dialogue data, the intermediate assistant replies are the more informative and detailed ones, while the last assistant reply is often short text such as "thank you" or "you're welcome". Training the model on only that last reply seriously hurts the training effect.

Method Two

Split one piece of multi-round dialogue data into multiple pieces of data. For example, the example above can be split into the three pieces of data sketched below.

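A rough sketch of the three resulting training samples (our own illustration; in each sample, only the final Assistant reply contributes to the loss):

Sample 1   input: User1                                           target: Assistant1
Sample 2   input: User1, Assistant1, User2                        target: Assistant2
Sample 3   input: User1, Assistant1, User2, Assistant2, User3     target: Assistant3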

Compared with Method One, Method Two makes full use of every Assistant reply in the multi-round dialogue. The drawback is that a piece of data containing n rounds of dialogue must be split into n pieces of data, which roughly multiplies the amount of data to be processed by n; the training method is therefore not efficient.

04

The Firefly Method

Method Introduction

When training the multi-round dialogue model, the Firefly project adopts a more sufficient and efficient method: we concatenate a piece of multi-round dialogue data, feed it into the model, and compute the loss at every position in parallel, while only the loss on the Assistant parts participates in the weight update.


Why does this approach work? The answer lies in the attention mask of the causal language model. For causal language models, of which GPT is the representative example, the attention mask is a lower-triangular matrix: when each token is encoded, it can only attend to the tokens before it, not to the tokens after it.

Therefore, the encoded output of the User1 part can only perceive the content of User1, not the text after it, so it can be used to predict the content of Assistant1. The encoded output of the User2 part can only see User1, Assistant1, and User2, so it can be used to predict Assistant2, and so on. For the entire sequence, a single forward pass through the model yields the logits at every position in parallel, from which the loss can be computed.
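As a minimal sketch of this property (using GPT-2 from Hugging Face transformers purely as a stand-in causal language model, not the Firefly model), a single forward pass already yields the logits for every position:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

text = "User1 ... Assistant1 ... User2 ... Assistant2 ..."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# One forward pass returns logits of shape [batch, seq_len, vocab_size].
# Thanks to the lower-triangular attention mask, the logits at position t depend
# only on tokens up to t, so every position's next-token prediction is obtained in parallel.
print(outputs.logits.shape)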


It is worth noting that GLM and UniLM are not strictly causal language models, because they use a prefix attention mask design: attention within the prefix is bidirectional, while attention in the prediction part is unidirectional.

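For intuition, here is a small PyTorch sketch (our own, not taken from GLM or UniLM code) of how the two mask types differ for a sequence of length n with a prefix of length p:

import torch

n, p = 6, 3  # hypothetical sequence length and prefix length

# Causal LM: lower-triangular mask, each token attends only to itself and earlier tokens
causal_mask = torch.tril(torch.ones(n, n))

# Prefix LM (GLM/UniLM style): bidirectional attention within the prefix,
# unidirectional (causal) attention for the remaining positions
prefix_mask = torch.tril(torch.ones(n, n))
prefix_mask[:, :p] = 1

print(causal_mask)
print(prefix_mask)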

Code

Talk is cheap, show me the code. Next, we describe at the code level how to implement sufficient and efficient multi-round dialogue training.

During training, Firefly concatenates the multiple rounds of a conversation into the following format and then tokenizes the result.

<s>input1</s>target1</s>input2</s>target2</s>...
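As a rough sketch of how this text could be assembled (our own illustration using literal <s> and </s> strings; a target_mask is built alongside the token ids, as described below):

BOS, EOS = "<s>", "</s>"
turns = [("input1", "target1"), ("input2", "target2")]

text = BOS
for user_input, assistant_reply in turns:
    text += user_input + EOS        # inputN</s>
    text += assistant_reply + EOS   # targetN</s>

# text == "<s>input1</s>target1</s>input2</s>target2</s>"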

If you prefer the data organization style of Alpaca or Vicuna, you can also organize the multi-round dialogue in the following format. In our experience, even with Firefly's simple data format above, the multi-round dialogue performance is impressive, so we tend not to add long prefix descriptions; the format below is provided only for reference.

Below is a conversation between a user and an assistant.
User: input1
Assistant: target1</s>
User: input2
Assistant: target2</s>
...

One point to note: during training, a </s> must be appended after each Assistant reply to mark the end of that round of generation. Otherwise, at inference time the model will have difficulty sampling </s> and will not be able to stop generating.

When generating input_ids, we also generate a target_mask with values of 0 or 1, which marks whether each token belongs to the target part, i.e., whether the model should predict it. The target_mask of every "target</s>" segment is all 1s, and everything else is 0.

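A minimal sketch of how input_ids and target_mask could be built together (our own illustration, assuming each utterance is tokenized separately with a Hugging Face tokenizer; this is not the exact Firefly preprocessing code):

def build_inputs(turns, tokenizer):
    # turns: list of (user_input, assistant_reply) pairs
    input_ids = [tokenizer.bos_token_id]   # <s>
    target_mask = [0]
    for user_input, assistant_reply in turns:
        # user part and its trailing </s>: not predicted, mask = 0
        ids = tokenizer.encode(user_input, add_special_tokens=False) + [tokenizer.eos_token_id]
        input_ids += ids
        target_mask += [0] * len(ids)
        # assistant part and its trailing </s>: predicted, mask = 1
        ids = tokenizer.encode(assistant_reply, add_special_tokens=False) + [tokenizer.eos_token_id]
        input_ids += ids
        target_mask += [1] * len(ids)
    return input_ids, target_mask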

We compute the loss at every position in parallel, but only the positions with target_mask = 1 contribute to the weight update. This approach makes full use of the model's parallel computation, so it is efficient, and every target segment in the multi-round dialogue takes part in training, so the data is fully utilized.

The loss calculation can be implemented with the following code:

import torch
from torch import nn


# `Loss` is a base class defined in the Firefly project
class TargetLMLoss(Loss):

    def __init__(self, ignore_index):
        super().__init__()
        self.ignore_index = ignore_index
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def __call__(self, model, inputs, training_args, return_outputs=False):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        target_mask = inputs['target_mask']
        # Forward pass of the model
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]

        # Set the parts of labels that do not belong to the target to ignore_index,
        # so that only the target part contributes to the loss
        labels = torch.where(target_mask == 1, input_ids, self.ignore_index)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return (loss, outputs) if return_outputs else loss
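As a rough usage sketch (assuming a Hugging Face Trainer subclass; the exact integration in the Firefly project may differ), the loss object can be plugged in via compute_loss:

from transformers import Trainer

class TargetLMTrainer(Trainer):
    # Hypothetical Trainer subclass that delegates loss computation to TargetLMLoss
    def __init__(self, *args, loss_func=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func

    def compute_loss(self, model, inputs, return_outputs=False):
        return self.loss_func(model, inputs, self.args, return_outputs=return_outputs)

# usage (names are placeholders):
# trainer = TargetLMTrainer(model=model, args=training_args, train_dataset=train_dataset,
#                           data_collator=data_collator, loss_func=TargetLMLoss(ignore_index=-100))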

05

Epilogue

In this article, we have described in detail the techniques and implementation the Firefly project uses to train multi-round dialogue models, achieving a more sufficient and efficient multi-round dialogue training method. We hope it helps you understand the approach better.


Source: blog.csdn.net/sinat_37574187/article/details/132183935