The Tricks of Large-Model RLHF

From: Baobao Algorithm Notes

In an earlier article, Big Model Interview Stereotypes, I raised a question: if the RM score keeps climbing during RLHF training, does that mean the trained LLM must be good?

Any judgment that absolute is bound to be flawed and worth doubting.

If you run the PPO pipeline a few times yourself, you will find that reinforcement learning for large models is genuinely hard to train. "Hard" here means two things: it burns GPUs, and it is very easy to train the model into collapse.

First, the GPU bill. Suppose you train LLaMA-7B with both the SFT model and the RM at 7B; then you need to hold 2 × 7B models in train mode (the policy model and the critic model) plus 2 × 7B models in eval mode (the reference model and the reward model) in memory at the same time.

Previously, a few 40GB A100s plus DeepSpeed were enough for full-parameter fine-tuning of a 7B model; for reinforcement learning you have to move up to 80GB A100s, and even then 7B barely fits. Anything larger and you have to pay up for more hardware.

Second, it collapses easily. After training, the LLM may stop following instructions altogether, or turn into an unstoppable repeater that rambles without logic until it hits max length, or go mute and lazily emit an EOS token right away.

These problems are actually very common in RL training for games: if the environment and hyperparameters are set up badly, the agent easily goes to extremes, flip-flopping between throwing itself at death and getting stuck in glitchy, repetitive loops.

Vanilla PPO is very hard to train. It is demanding about the SFT base model, the RM training data, the prompts sampled for rollouts, and the hyperparameter settings.

Ever since OpenAI kicked off the RLHF wave, everyone has assumed reinforcement learning is unbeatable for alignment, but running it yourself turns out to be another story. This thing is extremely finicky.

The devil is in the details. OpenAI is like a competition winner who shares the winning solution but never tells you how important each step was or what the key settings were, let alone the failed and useless attempts.

Credit to Fudan here. In a very solid technical report they put forward seven tricks, telling you not only the key settings but also the failure experience, including which choices will simply train the model to death.

Reference https://github.com/OpenLMLab/MOSS-RLHF

Before getting to the tricks: the Fudan MOSS team also added extra monitoring to the RLHF training process. These process metrics are very important in practice and make it easy to tell whether your model is going off the rails.

The figure below is also excellent: it clearly shows where each trick plugs into each stage of RLHF. The accompanying open-source implementation is equally clear and easy to follow, classic "noodle code" with no encapsulation, one script from top to bottom, convenient to read and to hack on.

Let's go through the seven tricks, which correspond to the parts marked with an asterisk on the right side of the figure.

[Figure: the RLHF/PPO pipeline from the MOSS-RLHF report, with the seven tricks marked by asterisks on the right]

1. KL Divergence Penalty at the Token Level

# per-token KL penalty: the policy is penalized for drifting away from the reference model
kl_penalty = (-self.kl_penalty_weight * (logprobs - ref_logprobs)).cpu()

The main problem this solves is training stability: it keeps the update steps from getting too large. If your policy's output drifts too far from the reference model, points get deducted from the reward.
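A minimal sketch of how such a per-token penalty is typically folded into the reward signal (the function and tensor names here are my own, not the MOSS-RLHF variable names):

import torch

def penalized_rewards(logprobs, ref_logprobs, scores, kl_penalty_weight=0.02):
    # logprobs / ref_logprobs: (batch, seq_len) log-probs of the sampled tokens
    # scores: (batch,) scalar reward-model score for each full response
    kl_penalty = -kl_penalty_weight * (logprobs - ref_logprobs)  # per-token penalty
    rewards = kl_penalty.clone()
    rewards[:, -1] += scores  # the RM score is added only at the final token
    return rewards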

2. Reward Normalization and Clipping

3. Value Function Loss Clipping

Clipping works like gradient clipping: anything too large gets capped, which limits abnormal rewards and losses. Normalization standardizes the rewards.

In the open-source code these tricks correspond to the following switches and can be checked one by one; the principle behind all of them is similar. A small sketch of the normalization and clipping logic follows the flag list below.

self.use_reward_clip: bool = opt.use_reward_clip
self.use_reward_norm: bool = opt.use_reward_norm
self.use_advantage_norm: bool = opt.use_advantage_norm
self.use_advantage_clip: bool = opt.use_advantage_clip
self.use_critic_loss_clip: bool = opt.use_critic_loss_clip
self.use_policy_loss_clip: bool = opt.use_policy_loss_clip
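A minimal sketch of reward normalization/clipping and value-function loss clipping, assuming simple whitening with the batch mean and standard deviation (the exact scheme and constants are my assumptions, not necessarily what the MOSS-RLHF code does):

import torch

def normalize_and_clip_rewards(rewards, clip_value=5.0, eps=1e-8):
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)   # reward normalization
    return torch.clamp(rewards, -clip_value, clip_value)           # reward clipping

def critic_loss_clipped(values, old_values, returns, clip_range=0.2):
    # value-function loss clipping: keep the new value prediction close to the old
    # one and take the worse (larger) of the two squared errors, PPO-style
    values_clipped = old_values + torch.clamp(values - old_values, -clip_range, clip_range)
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (values_clipped - returns) ** 2
    return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()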

4. Critic Model Initialization

Initializing the critic from the RM may not be a necessary choice. The authors ran experiments around this question and instead recommend pre-training the critic. The released code does not yet use RM initialization for this part; I will follow up on this issue.
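For reference, a plausible sketch of what a critic warm-up phase can look like before the PPO updates start: freeze the policy and fit the value head on collected rollouts first (the interface and schedule here are my assumptions, not the authors' exact recipe):

def warmup_critic(critic, optimizer, rollout_batches):
    # rollout_batches: iterable of (states, returns) pairs from sampled rollouts (hypothetical format)
    for states, returns in rollout_batches:
        values = critic(states)                      # predicted returns
        loss = ((values - returns) ** 2).mean()      # plain MSE value loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()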

5. Generalized Advantage Estimation

Appendix C.3 of the report contains the GAE parameter-tuning experiments; a minimal GAE sketch follows the figure below.

[Figure: GAE parameter-tuning results from Appendix C.3]
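A minimal sketch of standard Generalized Advantage Estimation (written from the textbook formulation, not copied from the MOSS-RLHF code; the gamma and lam defaults are illustrative):

import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # rewards, values: (seq_len,) tensors for a single trajectory
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]      # TD residual
        last_gae = delta + gamma * lam * last_gae                # exponentially weighted sum
        advantages[t] = last_gae
    returns = advantages + values                                # regression targets for the critic
    return advantages, returns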

6. Clipped Surrogate Objective

This is another regularization method that prevents overly large update steps and keeps training stable, and it is more efficient than the plain policy-gradient objective; see the sketch below.

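A minimal sketch of the standard PPO clipped surrogate loss (the clip_range default is illustrative):

import torch

def policy_loss_clipped(logprobs, old_logprobs, advantages, clip_range=0.2):
    ratio = torch.exp(logprobs - old_logprobs)       # pi_new / pi_old per token
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    return -torch.min(surr1, surr2).mean()           # maximize the objective = minimize its negative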

7. Global Gradient Clipping


The principle is the same as before: every flavor of clipping is just capping an over-large step, here applied to the global gradient norm.
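A minimal PyTorch sketch of a training step with global gradient clipping (the max_norm value is illustrative):

import torch

def clipped_step(model, loss, optimizer, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # cap the global gradient norm
    optimizer.step()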

In addition, the authors adopted a scheme from InstructGPT: mixing an LLM pretraining loss into the PPO training objective. Reference code:

# PPO loss: policy-gradient term + weighted value loss (+ optional entropy bonus)
if self.use_entropy_loss:
    loss1 = pg_loss + self.vf_loss_weight * vf_loss + self.entropy_loss_weight * entro_loss
else:
    loss1 = pg_loss + self.vf_loss_weight * vf_loss
# LM pretraining loss mixed in, as in InstructGPT's objective
loss2 = self.ppo_pretrain_loss_weight * pretrain_loss
loss = loss1 + loss2

To sum up, the improvements in the overall PPO-max recipe mainly target the stability of the training process. The ingredients are still the same old toolbox: clip the training signals, initialize carefully, and improve the loss. The real focus is on making RLHF easier to tune; I recommend running some experiments against the authors' source code.

The authors also left an easter egg in the paper: the second part of the technical report will cover their successes and pitfalls in training the Reward Model. It has not been released yet, so we are waiting for the update. People have long argued about what scale of RM to use, with some insisting the RM must be much larger than the SFT model; hopefully the authors settle it with an experiment~




Source: blog.csdn.net/qq_27590277/article/details/131778090