Fine-tuning Llama2 using DPO

Introduction


Reinforcement Learning from Human Feedback (RLHF) has become the final training step for LLMs such as GPT-4 and Claude: it aligns the model's outputs with human expectations in areas such as chat quality and safety. However, it also brings RL-specific complexity into NLP: one has to construct a good reward function and train a model to estimate the value of each state, and one must also keep the final LLM from drifting too far from the original model, otherwise it tends to produce gibberish rather than meaningful text. The process is intricate and involves many moving parts, and those parts themselves change dynamically during training, so it is not easy to manage well.

Rafailov, Sharma, Mitchell, et al. recently published a paper, Direct Preference Optimization, which proposes to convert the reinforcement-learning objective used by existing methods into an objective that can be optimized directly with a simple binary cross-entropy loss. This greatly simplifies the process of refining an LLM.

This article introduces the Direct Preference Optimization (DPO) method, which is now integrated into the TRL library, and shows how to fine-tune the latest Llama v2 7B model on the stack-exchange preference dataset, which contains ranked answers to questions from the various Stack Exchange portals.

DPO vs PPO

When optimizing human-derived preferences with RL, the traditional approach is to use an auxiliary reward model and fine-tune the target model so that it maximizes this reward through the RL machinery. Intuitively, the reward model gives feedback to the model being optimized, encouraging it to generate more high-reward outputs and fewer low-reward ones. At the same time, a frozen reference model keeps the outputs from deviating too much and helps preserve output diversity. This usually means adding a KL penalty with respect to the reference model on top of the reward-maximization objective, which helps prevent the model from cheating or exploiting the reward model.
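For intuition, here is a minimal sketch of the reward signal this paragraph describes (the function and tensor names are illustrative, not TRL's API): the reward model's sequence-level score minus a KL penalty against the frozen reference model, with the per-token KL approximated by the log-probability ratio.

import torch

def penalized_reward(reward_model_score, policy_logprobs, ref_logprobs, kl_coef=0.2):
    # per-token KL is approximated by the log-probability ratio policy vs. reference
    kl_penalty = kl_coef * (policy_logprobs - ref_logprobs)
    # subtract the summed penalty from the sequence-level reward-model score
    return reward_model_score - kl_penalty.sum(dim=-1)

# batch of 2 sequences of length 3; the numbers are purely illustrative
score = torch.tensor([1.2, 0.4])
policy_lp = torch.randn(2, 3)
ref_lp = torch.randn(2, 3)
print(penalized_reward(score, policy_lp, ref_lp))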

DPO bypasses the reward-modeling step entirely, thanks to a key insight: the analytical mapping from a reward function to its optimal RL policy. This mapping intuitively measures how well a given reward function fits the given preference data. With it, the authors convert the RL loss over the reward model and the reference model into a loss defined over the reference model alone, so the language model can be optimized directly on the preference data. In other words, DPO starts from the optimal solution of the RLHF objective and, via a change of variables, derives a loss that requires only the reference model.

With this loss, we can optimize the likelihood objective directly, without a reward model or a tedious reinforcement-learning optimization loop.
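Concretely, the resulting objective is a binary cross-entropy over the difference of the policy-vs-reference log-ratios for the chosen and rejected answers. A minimal sketch with illustrative names (not the TRL implementation):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # how much more likely the policy makes each answer than the reference does
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # binary cross-entropy with the "chosen beats rejected" label always equal to 1
    logits = beta * (chosen_logratio - rejected_logratio)
    return -F.logsigmoid(logits).mean()

# sequence-level log-probabilities for a batch of 2 preference pairs (illustrative)
loss = dpo_loss(torch.tensor([-10.0, -12.0]), torch.tensor([-14.0, -13.0]),
                torch.tensor([-11.0, -12.5]), torch.tensor([-13.0, -12.0]))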

How to use TRL for training

As mentioned before, a typical RLHF pipeline usually contains the following links:

  1. Supervised fine-tuning (SFT)
  2. Label data with preference labels
  3. Train a reward model based on preference data
  4. RL optimization

The TRL library contains the tools needed for all of these steps. DPO training eliminates reward modeling and RL (steps 3 and 4) entirely and directly optimizes the DPO objective on the annotated preference data.

With DPO, we still need to perform step 1, but instead of steps 3 and 4 we only need to provide the preference data prepared in step 2 to the DPOTrainer in TRL. The annotated preference data must follow a specific format: a dictionary containing the following 3 keys:

  • prompt : the prompt input to the model during inference
  • chosen : the better answer to the given prompt
  • rejected : the less-preferred answer to the given prompt, or an answer that does not belong to the given prompt

For example, for the stack-exchange preference dataset, we can use the following utility function to map the samples in the dataset to the above dictionary format and remove all original columns:

from typing import Dict, List
from datasets import load_dataset

def return_prompt_and_responses(samples) -> Dict[str, List[str]]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"],   # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset = dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)
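A quick sanity check that the mapping produced exactly the keys DPOTrainer expects (the printed prompt is illustrative, not actual dataset content):

print(dataset.column_names)  # ['prompt', 'chosen', 'rejected']
print(dataset[0]["prompt"])  # e.g. "Question: How do I ...?\n\nAnswer: "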

Once the dataset is in this format, the DPO loss is essentially a supervised loss with an implicit reward provided by the reference model. Therefore, at a high level, DPOTrainer requires the base model we want to optimize together with a reference model:

dpo_trainer = DPOTrainer(
    model,                 # the base model from the SFT pipeline
    model_ref,             # a copy of the SFT base model
    beta=0.1,              # temperature hyperparameter of the DPO loss
    train_dataset=dataset, # the dataset prepared above
    tokenizer=tokenizer,   # tokenizer
    args=training_args,    # training arguments, e.g. batch size, learning rate, etc.
)

Here, the hyperparameter beta is the temperature of the DPO loss, usually between 0.1 and 0.5. It controls how much attention we pay to the reference model: the smaller beta is, the more we ignore the reference model. After initializing the trainer, we can simply call the following method to train on the given dataset with the given training_args:

dpo_trainer.train()
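The training_args above are standard transformers TrainingArguments; a hedged sketch with illustrative values (not the exact settings used in the example scripts):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./dpo_llama2",       # illustrative output path
    per_device_train_batch_size=4,
    learning_rate=5e-4,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=500,
    bf16=True,
    remove_unused_columns=False,     # keep the prompt/chosen/rejected columns
)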

Experiment based on Llama v2

The advantage of implementing the DPO trainer in TRL is that we can take advantage of the LLM-related features already available in TRL and its companion libraries (such as Peft and Accelerate). With these libraries, we can even train the Llama v2 model with QLoRA, thanks to the 4-bit quantization provided by the bitsandbytes library.

Supervised fine-tuning

As mentioned above, we first use TRL's SFTTrainer to perform supervised fine-tuning on the 7B Llama v2 model using QLoRA on the SFT data subset:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name, # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args, # HF Trainer arguments
)
trainer.train()
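After training finishes, the SFT adapter needs to be saved so the DPO stage can reload it; a minimal sketch (output_dir is an assumed script argument):

# save the SFT LoRA adapter; the DPO stage loads it via AutoPeftModelForCausalLM
trainer.save_model(script_args.output_dir)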

DPO training

After SFT is complete, we save the resulting model and move on to DPO training. We use the SFT model as both the base model and the reference model for DPO, and train it with the DPO objective on the stack-exchange preference data prepared above. Since we chose to fine-tune the model with LoRA, we load it with Peft's AutoPeftModelForCausalLM helper:

import torch
from peft import AutoPeftModelForCausalLM
from trl import DPOTrainer

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # location of saved SFT model
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # same model as the main one
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

As you can see, we load the model in 4-bit and then train it with the QLoRA method via the peft_config argument. The trainer also uses the evaluation dataset to track training progress and reports key metrics, such as the implicit reward, which can optionally be recorded and displayed via WandB. Finally, we can push the trained model to the Hugging Face Hub.
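As a minimal sketch of that last step, one might merge the LoRA adapter into the base weights and push the result (the checkpoint path and repository id are illustrative):

from peft import AutoPeftModelForCausalLM
import torch

model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo_output_dir",              # illustrative path to the DPO checkpoint
    torch_dtype=torch.float16,
)
merged_model = model.merge_and_unload()  # fold the LoRA weights into the base model
merged_model.push_to_hub("your-username/stack-llama-2-dpo")  # illustrative repo id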

Summary

The complete source code of the SFT and DPO training scripts can be found in the examples/stack_llama_2 directory, and the trained merged model has also been uploaded to the HF Hub (see here).

Origin blog.csdn.net/specssss/article/details/132495138