LLM Training Trick Series: Rejection Sampling

From: NLP Workstation


Today I am bringing you an article on rejection sampling for LLMs by @知识dikw.

知乎:https://zhuanlan.zhihu.com/p/649731916

After reading this article, you will learn and understand:

  • What is rejection sampling?

  • Which LLMs use rejection sampling fine-tuning (RFT) during training?

  • Why is rejection sampling necessary?

  • How much improvement does rejection sampling bring?

  • What is the relationship between rejection sampling and reinforcement learning?

  • What is the relationship between RFT and SFT?

  • Why can RFT bring improvement?

Background introduction

Rejection sampling is a Monte Carlo algorithm for sampling data from a complex ("difficult to sample") distribution with the aid of a surrogate distribution.

What is Monte Carlo? If a method or algorithm uses random numbers to solve a problem, it is classified as a Monte Carlo method. In the context of rejection sampling, Monte Carlo (i.e., randomness) is what drives the acceptance criterion of the algorithm. As for sampling, a core idea shared by almost all Monte Carlo methods is: if you cannot sample from your target distribution directly, sample from another distribution instead (hence called the proposal function). A classic illustration is the dart-throwing experiment: throw random points onto a rectangle enclosing a circle and use the frequency of points that land inside the circle to estimate the circle's area and the value of π.

However, the sampling procedure must "follow the target distribution": we should obtain samples in proportion to how likely they are to occur. In simple terms, high-probability regions should contribute more samples.

This also means that when we use a proposal function, we must introduce a correction to ensure that the sampling procedure still follows the target distribution. This correction takes the form of an acceptance criterion.

The main idea behind the method is: if we want to sample from a distribution p(x), we use another, easier distribution q(x) to help us sample from p(x). The only restriction is that p(x) ≤ M·q(x) for all x, for some constant M > 1. Rejection sampling is mainly used when the form of p(x) makes it difficult to sample from directly, but p(x) can still be evaluated at any point x.

Here is a breakdown of the algorithm:

  1. Sample x from q(x).

  2. Sample y from U(0, Mq(x)) (uniform distribution).

  3. If y < p(x), accept x as a sample of p(x), otherwise return to step 1.

This approach works because the uniform draw lets us compare the "envelope" M·q(x) against the target probability density p(x). Another way to see it: the probability that we end up accepting a sample at a point x0 is proportional to the probability of drawing x0 from q times the acceptance rate at x0, which is simply the ratio p(x0) / (M·q(x0)). In the figure, once we draw a sample from q(x) (in this example, x = 2), we sample y uniformly between 0 and the height of M·q(x) at that point. If y falls below the target probability density function, we accept the sample (shown in green); otherwise, we reject it.
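To make the three steps concrete, here is a minimal NumPy sketch. The bimodal target p(x), the Gaussian proposal q(x), and the envelope constant M = 10 are illustrative choices for this example, not anything taken from the original article:

```python
import numpy as np

# Illustrative target density p(x): an (unnormalized) mixture of two Gaussians.
# Rejection sampling only needs p(x) to be evaluable point-wise and p(x) <= M * q(x).
def p(x):
    return 0.5 * np.exp(-0.5 * (x - 2) ** 2) + 0.5 * np.exp(-0.5 * (x + 2) ** 2)

# Proposal density q(x): a wide zero-mean Gaussian that is easy to sample from.
def q(x):
    return np.exp(-0.5 * (x / 3) ** 2) / (3 * np.sqrt(2 * np.pi))

M = 10.0  # envelope constant, chosen so that p(x) <= M * q(x) everywhere

def rejection_sample(n, seed=0):
    rng = np.random.default_rng(seed)
    samples = []
    while len(samples) < n:
        x = rng.normal(0.0, 3.0)          # step 1: draw x ~ q
        y = rng.uniform(0.0, M * q(x))    # step 2: draw y ~ U(0, M * q(x))
        if y < p(x):                      # step 3: accept x if y falls under p(x)
            samples.append(x)
    return np.array(samples)

draws = rejection_sample(1000)  # a histogram of `draws` approximates p(x)
```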

In the context of our generative models, the rejection sampling fine-tuning discussed here usually means drawing K samples from an already fine-tuned model (fine-tuned with SFT, the PPO algorithm, etc.), applying a rejection/acceptance function to filter those samples down to the ones that match our target distribution, and then fine-tuning the model on the accepted samples.

Related research

Rejection sampling is a simple yet effective fine-tuning augmentation technique that is also used to align LLMs with human preferences.

WebGPT: Browser-assisted question-answering with human feedback

Rejection sampling (best-of-n). We sampled a fixed number of answers (4, 16 or 64) from either the BC model or the RL model (if left unspecified, we used the BC model), and selected the one that was ranked highest by the reward model. We used this as an alternative method of optimizing against the reward model, which requires no additional training, but instead uses more inference-time compute.

Even though both rejection sampling and RL optimize against the same reward model, there are several possible reasons why rejection sampling outperforms RL:

  • 1. It may help to have many answering attempts, simply to make use of more inference-time compute.

  • 2. The environment is unpredictable: with rejection sampling, the model can try visiting many more websites, and then evaluate the information it finds with the benefit of hindsight.

  • 3. The reward model was trained primarily on data collected from BC and rejection sampling policies, which may have made it more robust to over-optimization by rejection sampling than by RL.

  • 4. RL requires hyperparameter tuning, whereas rejection sampling does not.

Simply put, WebGPT only uses rejection sampling (best-of-n) at inference time and does not use it for fine-tuning. The authors then compared RL with rejection sampling, found that rejection sampling performed better, and offered several explanations: compared with the RL algorithm, rejection sampling requires no hyperparameter tuning and is more robust to over-optimization of the reward model.
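As a rough illustration of this inference-time best-of-n procedure, here is a minimal sketch; generate_answer and reward_model_score are hypothetical stand-ins for the policy and the reward model, not WebGPT's actual interfaces:

```python
from typing import Callable, List

def best_of_n(question: str,
              generate_answer: Callable[[str], str],
              reward_model_score: Callable[[str, str], float],
              n: int = 16) -> str:
    """Sample n candidate answers and return the one the reward model ranks highest.

    No gradient update is involved: the only extra cost is inference-time compute.
    """
    candidates: List[str] = [generate_answer(question) for _ in range(n)]
    scores = [reward_model_score(question, ans) for ans in candidates]
    return candidates[scores.index(max(scores))]
```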

Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback

Rejection Sampling (RS) with a 52B preference model, where samples were generated from a 52B context-distilled LM. In this case the number k of samples was a parameter, but most often we used k = 16.

We also test our online models' performance during training (Figure 15), comparing various levels of rejection sampling. In Figure 36 we show helpfulness Elo scores for a 52B context-distilled model with rejection sampling (utilizing a 52B preference model trained on pure helpfulness) for k = 1, 4, 16, 64, showing that higher values of k clearly perform better. Note that the context-distilled model and the preference models discussed here were trained during an earlier stage of our research with different datasets and settings from those discussed elsewhere in the paper, so they are not directly comparable with other Elo results, though very roughly and heuristically, our online models seem to perform about as well or better than k = 64 rejection sampling. Note that k = 64 rejection sampling corresponds to D_KL = log(64) ≈ 4.2.
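As a side note on the D_KL = log(64) figure quoted above: for best-of-k selection under a reward model (i.i.d. samples, no reward ties), the KL divergence between the best-of-k policy and the base policy has a simple closed form, of which log k is the commonly used upper-bound estimate:

```latex
% KL divergence between the best-of-k (rejection sampling) policy and the base policy,
% assuming k i.i.d. draws and no reward ties:
\[
  D_{\mathrm{KL}}\!\left(\pi_{\text{best-of-}k} \,\middle\|\, \pi\right)
  \;=\; \log k - \frac{k-1}{k}
  \;\le\; \log k .
\]
```

For k = 64 the bound gives log 64 ≈ 4.16, which is the ≈ 4.2 figure used in the quote.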

To sum up, this work again uses rejection sampling only at the inference stage, and the larger the value of k used when sampling, the better the effect. The online RLHF model, however, seems to perform about as well as or better than k = 64 rejection sampling.

Aligning Large Language Models through Synthetic Feedback

An important additional component is that we leverage the synthetic RM from the previous stage to ensure the quality of the model-to-model conversations with rejection sampling over the generated outputs (Ouyang et al., 2022). We train LLaMA-7B on the synthetic demonstrations (SFT) and further optimize the model with rewards from the synthetic RM, namely, Reinforcement Learning from Synthetic Feedback (RLSF).

To ensure a more aligned response from the assistant, we suggest including the synthetic RM, trained in the first stage, in the loop, namely Reward-Model-guided Self-Play (RMSP). In this setup, the assistant model, LLaMA-30B-Faithful-3shot, first samples N responses for a given conversational context. Then, the RM scores the N responses, and the best-scored response is chosen as the final response for the simulation, i.e., the RM performs rejection sampling (best-of-N sampling) (Nakano et al., 2021; Ouyang et al., 2022). Other procedures are the same as the Self-Play. Please see Figure 8 for the examples.

The difference from the previous two papers is that here the rejection-sampled data is actually used for fine-tuning. ICL (in-context learning) is used to have models of different sizes generate responses to each prompt; assuming the larger model's responses are better than the smaller model's, this synthetic preference data is used to train the RM. Rejection sampling is then applied: the RM selects the highest-scoring response to build the training set, and the model is trained on it with SFT.
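A rough sketch of the RM-guided self-play loop described above, where every assistant turn is chosen by best-of-N under the reward model; user_model, assistant_model, and rm_score are hypothetical interfaces, and the real RMSP setup (specific LLaMA checkpoints, few-shot prompts) is not reproduced here:

```python
from typing import Callable, List, Tuple

Turn = Tuple[str, str]  # (role, text)

def rmsp_dialogue(user_model: Callable[[List[Turn]], str],
                  assistant_model: Callable[[List[Turn]], str],
                  rm_score: Callable[[List[Turn], str], float],
                  turns: int = 3,
                  n: int = 4) -> List[Turn]:
    """Simulate a conversation in which every assistant turn is best-of-n under the RM."""
    history: List[Turn] = []
    for _ in range(turns):
        user_msg = user_model(history)                        # user model proposes a query
        ctx = history + [("user", user_msg)]
        candidates = [assistant_model(ctx) for _ in range(n)]   # n sampled responses
        best = max(candidates, key=lambda r: rm_score(ctx, r))  # RM rejection sampling
        history = ctx + [("assistant", best)]                   # keep the RM-preferred one
    return history  # one synthetic demonstration, later used for SFT
```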

Llama 2: Open Foundation and Fine-Tuned Chat Models

This process begins with the pretraining of Llama 2 using publicly available online sources. Following this, we create an initial version of Llama 2-Chat through the application of supervised fine-tuning. Subsequently, the model is iteratively refined using Reinforcement Learning with Human Feedback (RLHF) methodologies, specifically through rejection sampling and Proximal Policy Optimization (PPO). Throughout the RLHF stage, the accumulation of iterative reward modeling data in parallel with model enhancements is crucial to ensure the reward models remain within distribution.

Rejection Sampling fine-tuning. We sample K outputs from the model and select the best candidate with our reward, consistent with Bai et al. (2022b). The same re-ranking strategy for LLMs was also proposed in Deng et al. (2019), where the reward is seen as an energy function. Here, we go one step further, and use the selected outputs for a gradient update. For each prompt, the sample obtaining the highest reward score is considered the new gold standard. Similar to Scialom et al. (2020a), we then fine-tune our model on the new set of ranked samples, reinforcing the reward.

The two RL algorithms mainly differ in:

  • Breadth — in Rejection Sampling, the model explores K samples for a given prompt, while only one generation is done for PPO.

  • Depth — in PPO, during training at step t the sample is a function of the updated model policy from t − 1 after the gradient update of the previous step. In Rejection Sampling fine-tuning, we sample all the outputs given the initial policy of our model to collect a new dataset, before applying the fine-tuning similar to SFT. However, since we applied iterative model updates, the fundamental differences between the two RL algorithms are less pronounced.

To summarize, the RLHF methods used are PPO and Rejection Sampling (RS) fine-tuning (similar to best-of-N sampling). PPO is the most popular on-policy RL algorithm (essentially trial-and-error learning). As the paper notes, rejection sampling fine-tuning goes one step further and uses the selected outputs for a gradient update: for each prompt, the sample obtaining the highest reward score is treated as the new gold standard, and, similar to Scialom et al. (2020a), the model is then fine-tuned on this new set of ranked samples, reinforcing the reward.
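Putting the quoted description into a compressed Python sketch, the rejection sampling fine-tuning loop might look like the following; policy.generate, reward_model.score, and finetune_fn are placeholders, and k and the number of iterations are illustrative values rather than Llama 2's actual settings:

```python
from typing import List, Tuple

def rejection_sampling_finetune(policy, reward_model, prompts: List[str], finetune_fn,
                                k: int = 8, iterations: int = 3):
    """Iterative RFT: sample K outputs per prompt, keep the highest-reward one as the
    new 'gold' answer, fine-tune on those answers, then repeat with the updated policy."""
    for _ in range(iterations):
        gold: List[Tuple[str, str]] = []
        for prompt in prompts:
            candidates = [policy.generate(prompt) for _ in range(k)]     # breadth: K samples
            best = max(candidates, key=lambda c: reward_model.score(prompt, c))
            gold.append((prompt, best))                                  # best-of-K as gold
        policy = finetune_fn(policy, gold)   # SFT-style gradient update on selected outputs
    return policy
```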

This shows that Llama 2 uses the RM to select rejection-sampled outputs and performs SFT-style training on them to update the policy model. At the same time, the rejection-sampled outputs treated as gold are also used to retrain the RM from the old checkpoint, strengthening the RM's reward signal. So the author's view is that rejection sampling fine-tuning here serves to improve both the SFT (policy) model and the RM.

Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

To augment more data samples for improving model performances without any human effort, we propose to apply Rejection sampling Fine-Tuning (RFT). RFT uses supervised models to generate and collect correct reasoning paths as augmented fine-tuning datasets. We find with augmented samples containing more distinct reasoning paths, RFT improves mathematical reasoning performance more for LLMs. We also find RFT brings more improvement for less performant LLMs. Furthermore, we combine rejection samples from multiple models which push LLaMA-7B to an accuracy of 49.3% and outperforms the supervised fine-tuning (SFT) accuracy of 35.9% significantly.

In the paper's figure, the RFT model clearly improves over the SFT model on GSM8K.

In short, to improve model performance by adding more data samples without any human effort, RFT uses a supervised model to generate and collect correct reasoning paths as an augmented fine-tuning dataset; the gains are larger when the augmented samples contain more distinct reasoning paths and when the base LLM is less performant, and combining rejection samples from multiple models pushes LLaMA-7B from an SFT accuracy of 35.9% to 49.3%. It is worth noting that, unlike the works above that use an RM to perform rejection sampling and pick the best response, here the acceptance test simply compares the model's answer with the ground-truth answer and keeps the reasoning paths that arrive at the correct result.
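A simplified sketch of this RFT data-collection step: the acceptance test is answer matching against the ground truth rather than a reward model score, and distinct correct reasoning paths are kept. Deduplicating by the exact path string is a simplification of the paper's criterion, and model_generate / extract_answer are placeholder functions:

```python
from typing import Callable, List, Set, Tuple

def collect_rft_data(model_generate: Callable[[str], str],
                     extract_answer: Callable[[str], str],
                     dataset: List[Tuple[str, str]],   # (question, gold_answer) pairs
                     k: int = 100) -> List[Tuple[str, str]]:
    """Keep only sampled reasoning paths whose final answer matches the gold answer,
    deduplicating paths per question so the augmented set stays diverse."""
    augmented: List[Tuple[str, str]] = []
    for question, gold_answer in dataset:
        seen: Set[str] = set()
        for _ in range(k):
            path = model_generate(question)            # sample one chain of thought
            if extract_answer(path) == gold_answer and path not in seen:
                seen.add(path)                         # a new, correct reasoning path
                augmented.append((question, path))
    return augmented  # fine-tune the model on `augmented` (the RFT step)
```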

RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

However, the inefficiencies and instabilities associated with RL algorithms frequently present substantial obstacles to the successful alignment of generative models, necessitating the development of a more robust and streamlined approach. To this end, we introduce a new framework, Reward rAnked FineTuning (RAFT), designed to align generative models more effectively. Utilizing a reward model and a sufficient number of samples, our approach selects the high-quality samples, discarding those that exhibit undesired behavior, and subsequently assembles a streaming dataset. This dataset serves as the basis for aligning the generative model and can be employed under both offline and online settings. Notably, the sample generation process within RAFT is gradient-free, rendering it compatible with black-box generators. Through extensive experiments, we demonstrate that our proposed algorithm exhibits strong performance in the context of both large language models and diffusion models.
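A schematic of one RAFT iteration as described in the abstract (generate, rank by reward, keep the top slice, fine-tune); the function names are placeholders and keep_ratio is an assumed hyperparameter, not a value from the paper:

```python
from typing import List, Tuple

def raft_step(policy, reward_model, prompts: List[str], finetune_fn,
              samples_per_prompt: int = 4, keep_ratio: float = 0.25):
    """One RAFT iteration: generate, rank by reward, keep the top slice, fine-tune.
    Generation is gradient-free, so the sampler could even be a black-box model."""
    batch: List[Tuple[float, str, str]] = []
    for prompt in prompts:
        for _ in range(samples_per_prompt):
            out = policy.generate(prompt)
            batch.append((reward_model.score(prompt, out), prompt, out))
    batch.sort(key=lambda t: t[0], reverse=True)                # rank by reward
    kept = [(p, o) for _, p, o in batch[: int(len(batch) * keep_ratio)]]
    return finetune_fn(policy, kept)                            # align on high-reward subset
```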

Summary and reflection

Rejection sampling filters the outputs of the SFT model through a rejection/acceptance function (which can be a reward model or a heuristic rule), yielding a distribution of higher-quality answers and thus improving the final reward. For rejection sampling, a larger sample count K generally works better. Within the RLHF framework, rejection sampling fine-tuning can be used to improve the SFT model itself; since the PPO algorithm usually requires the gap between the old and new policies to stay small, a stronger SFT starting point also matters a great deal for PPO. In addition, rejection-sampled data can be used to iterate the old reward model and strengthen its reward signal, which is likewise important for improving and iterating PPO. Finally, for chain-of-thought (CoT) ability, rejection sampling provides more reasoning paths for the model to learn from, which is also very valuable.




Origin blog.csdn.net/qq_27590277/article/details/132419439