Reinforcement Learning from Human Feedback (RLHF) in ChatGPT, in Action


Team Blog: CSDN AI Group


1 Introduction

In today's digital age, the popularity of ChatGPT continues to grow. ChatGPT can handle complex language tasks, freeing up human labor, improving work efficiency, and reducing costs. Its advanced technology and wide range of applications make it one of the hottest artificial intelligence technologies today. Enterprises, academic institutions, and technology enthusiasts alike are looking forward to what ChatGPT can be applied to.

In this context, the CSDN AI team also wanted to reproduce a simplified version of ChatGPT. According to the ChatGPT official blog, ChatGPT is trained in essentially the same way as InstructGPT (as shown in Figure 1); only the datasets differ. Our reproduction therefore mainly follows the InstructGPT training procedure, with RWKV as the base model. The process breaks down into the following four stages:

  • (1) Language Model Pre-training;
  • (2) Supervised Fine-Tuning (SFT);
  • (3) Reward Model Training (Reward Modeling, RM);
  • (4) Use the proximal policy optimization algorithm for reinforcement learning (Proximal Policy Optimization, PPO).

Stages (1) and (2), pre-training and SFT, were completed by @zxm2015; please refer to the article Exploring the Large Language Model 1. This article mainly covers stages (3) and (4), i.e. Reinforcement Learning from Human Feedback (RLHF).


Figure 1 The training process of the InstructGPT model

2 Reinforcement Learning from Human Feedback (RLHF)

Reinforcement Learning from Human Feedback (RLHF) is the technique used in ChatGPT to improve the quality of its answers. It is a reinforcement-learning-based approach that optimizes the model's responses by incorporating human feedback.

In RLHF, ChatGPT learns to improve the quality of its responses by interacting with human users. When ChatGPT generates an answer, it presents the answer to the user and asks for feedback. Users can rate the answers, for example as "excellent", "good", "average", or "poor". ChatGPT uses this feedback as a reward or punishment signal to update its model so that it better meets user needs.

RLHF can be divided into two parts. The first is the reward model, which is where human feedback mainly enters the pipeline; the second is a reinforcement learning stage using the Proximal Policy Optimization (PPO) algorithm, which optimizes the model based on feedback from the reward model and finally yields a language model aligned with human preferences. These two parts are described in detail below.

2.1 Reward Model (RM)

Before RLHF, the language model has already been supervised fine-tuned (referred to below as the SFT Model). The task of the reward model is to score the SFT Model's replies: the higher the score, the better the answer. Once trained, the reward model is used in the subsequent PPO stage for reinforcement learning fine-tuning; it is a sub-component of PPO that provides the reward signal for PPO training.

(1) Input and output of the model
The model's input is the pair <Prompt, Response>, i.e. the user's question (Prompt) together with the SFT Model's reply (Response); its output is a reward score, as shown in the following figure:


Figure 2 Input and output of RM
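
To make the interface concrete, here is a minimal, hypothetical usage sketch; the tokenizer and the trained reward_model object are stand-ins for illustration (the call signature with prompt_lengths matches the forward function shown later in this section):

import torch

# hypothetical objects for illustration: a `tokenizer` and a trained `reward_model`
prompt = "How do I boil an egg?"
response = "Place the egg in boiling water and cook it for about 8 minutes."

# the RM scores the concatenated <Prompt, Response> token sequence
tokens = torch.tensor([tokenizer.encode(prompt + response)])      # shape: (1, seq_len)
prompt_lengths = torch.tensor([len(tokenizer.encode(prompt))])    # length of the prompt part

with torch.no_grad():
    score = reward_model(tokens, prompt_lengths=prompt_lengths)   # one scalar reward per sample

print(score.item())  # a higher score means a better response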

(2) The construction of the data set
This stage mainly trains the RM on manually labeled data, and this is where human feedback enters the pipeline. Questions are randomly sampled from the Prompts dataset, and for each question K different responses are generated. Human annotators then rank these responses by overall quality (considering, for example, relevance, informativeness, and harmful content) to produce an ordering.

Given the reward model's input and output described above, constructing the dataset would ideally mean assigning a manual score to each <Prompt, Response> pair. In practice, however, it is hard for annotators to give consistent continuous scores to multiple answers, which slows down labeling. Moreover, what we actually care about is which of several candidates is better or worse. It is therefore sufficient for annotators to rank the candidates; the dataset is then built from the ranked answers, with an appropriate loss function.

Empirically, human annotators sort fastest and most accurately when given 4 to 9 options (that is, K ∈ {4, 5, 6, 7, 8, 9}). Here we set K = 4, so each Prompt ultimately yields C(4, 2) = 6 training samples.

Specifically, suppose we select a question x and use the SFT Model to generate 4 answers {y1, y2, y3, y4}, which human annotators rank as y4 > y3 > y1 > y2. We then obtain the training samples shown below, where the <Prompt, Response> pair on the left scores higher than the one on the right (a small sketch of generating such pairs programmatically follows the list):

(<x, y4>, <x, y3>)
(<x, y4>, <x, y1>)
(<x, y4>, <x, y2>)
(<x, y3>, <x, y1>)
(<x, y3>, <x, y2>)
(<x, y1>, <x, y2>)
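
The pairing step itself is mechanical. The sketch below (illustrative, not repository code) turns one ranked list of K responses into the C(K, 2) preference pairs:

from itertools import combinations

# ranked best -> worst by the human annotators
prompt = "x"
ranked_responses = ["y4", "y3", "y1", "y2"]

# every pair (better, worse): the left element is always ranked above the right one
pairs = [
    ((prompt, better), (prompt, worse))
    for better, worse in combinations(ranked_responses, 2)
]

assert len(pairs) == 6  # C(4, 2) = 6 training samples per prompt
for preferred, rejected in pairs:
    print(preferred, ">", rejected)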

(3) Loss function
Given the dataset constructed above, we do not have continuous score targets for training the reward model; what we have are preferred/rejected sample pairs. The loss function is therefore the pairwise ranking loss below, which is minimized during training:

$$\text{loss}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim D}\big[\log\sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)\big]$$

where r_θ(x, y) is the score the RM assigns to the input <x, y>, θ are the RM's parameters, and y_w and y_l are two different answers generated by the SFT Model for the same input x, with y_w ranked above y_l by the human annotators.

# loss function: pairwise ranking loss, -log(sigmoid(r_preferred - r_rejected))
import torch

def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

(4) Core code
Compared with the SFT Model, the RM's network structure requires only minor changes: after the <Prompt, Response> pair is fed in, the embedding of the last token is taken directly and passed through a linear layer to compute the reward score.

a) Linear layer:

# linear head for computing the reward score
self.pred_reward = nn.Linear(dim, 1, bias=False)

b) forward function

    def forward(
        self,
        x,
        mask = None,
        prompt_mask = None,
        prompt_lengths = None
    ):

        # only one of prompt_mask and prompt_lengths may be provided (mutually exclusive)
        assert not (exists(prompt_mask) and exists(prompt_lengths))

        # derive prompt mask from prompt lengths
        if exists(prompt_lengths):
            batch, seq_len = x.shape
            arange = torch.arange(seq_len, device=x.device)
            prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

        # reward model should have an understanding of which section is prompt, and which section is response
        # pick the extra embedding for each token from prompt_embed / response_embed
        # according to the corresponding entry in prompt_mask
        prompt_response_mask_embed = torch.stack([
            self.prompt_embed,
            self.response_embed,
            self.padding_embed
        ]).to(prompt_mask.device)
        extra_embed = None
        if exists(prompt_mask):
            extra_embed = prompt_response_mask_embed[prompt_mask]            

        # take the embedding of the last token
        last_token_embeds = self.rwkv(
            x,
            extra_embed=extra_embed,
            rm_train=True
        )[:, -1, :]

        # compute the reward score
        reward = self.pred_reward(last_token_embeds)
        reward = reward.squeeze(-1)

        return reward

c) train_forward function

    def train_forward(self, x_p, x_a, m_p, m_a):
        # the forward pass runs the model twice, so the parameters must be frozen for one of the passes
        # otherwise the gradient would be computed twice, which raises an error under the deepspeed framework
        # error message: Gradient computed twice for this partition.

        with torch.enable_grad():
            prefer_reward = self.forward(x_p, prompt_mask=m_p)
        with torch.no_grad():
            alter_reward = self.forward(x_a, prompt_mask=m_a)

        return prefer_reward, alter_reward
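
Putting the pieces together, one RM optimization step could look like the minimal sketch below (the reward_model, optimizer, and batch tensors x_prefer / x_alter with their prompt masks are assumed names for illustration, not repository code):

import torch

def rm_train_step(reward_model, optimizer, x_prefer, x_alter, mask_prefer, mask_alter):
    # score the preferred and the rejected <Prompt, Response> pairs
    prefer_reward, alter_reward = reward_model.train_forward(
        x_prefer, x_alter, mask_prefer, mask_alter
    )

    # pairwise ranking loss: push the preferred score above the rejected one
    loss = -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()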

2.2 Proximal Policy Optimization Algorithm (PPO)

Proximal Policy Optimization (PPO) is a deep reinforcement learning algorithm whose goal is to learn a policy that maximizes long-term cumulative returns.
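
At its core, PPO maximizes the standard clipped surrogate objective from the PPO paper; in the training_step code further below this corresponds to the surr1 / surr2 terms, with ε being eps_clip and Â_t the advantage estimate:

$$L^{CLIP}(\theta) = \mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}$$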


Figure 3 Detailed version of PPO training architecture

(1) The PPO algorithm includes the following main parts:

  • a) Policy Network
    Learns and outputs a probability distribution over actions for a given state; it is usually a neural network updated based on feedback from the environment. It corresponds to the Actor in Figure 3, is initialized from the SFT Model, and is trained during PPO.

  • b) Value Network
    Predicts the expected return for a given state; it is also usually a neural network, and its output is used to compute the advantage function, which in turn helps update the policy network. It corresponds to the Critic in Figure 3, is initialized from the RM, and is trained during PPO.

  • c) Reward Model
    Corresponds to the Reward Model in Figure 3 and is the model trained in Section 2.1. It is not trained during PPO; it only provides the reward signal for PPO training.

  • d) SFT Model
    Corresponds to the Supervised Fine-Tune Model in Figure 3. It serves as a reference when updating the policy network: by limiting the magnitude of each update, it ensures the updated policy does not deviate too far from the original one. This part may or may not be involved in training; when it is involved, the algorithm is called PPO-ptx.

  • e) Experience sampling
    Collects experience data from interaction with the environment for updating the policy network and value network. In PPO, experience sampling usually follows a strategy based on action-value estimation. It corresponds to the Prompts -> Actor -> Response process at the top of Figure 3. (A rough sketch of how these components can be assembled is given right after this list.)
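
As a rough illustration of how these components relate, here is a hypothetical assembly sketch (not the repository's actual class layout; copy.deepcopy of the SFT / RM models stands in for loading separate checkpoints):

import copy
import torch.nn as nn

class PPOComponents:
    ''' hypothetical container for the four models described above '''
    def __init__(self, sft_model: nn.Module, reward_model: nn.Module):
        # a) Actor: initialized from the SFT Model, trainable
        self.actor = copy.deepcopy(sft_model)

        # b) Critic: initialized from the RM, trainable
        self.critic = copy.deepcopy(reward_model)

        # c) Reward Model: frozen, only provides the reward signal
        self.reward_model = reward_model.eval()
        for p in self.reward_model.parameters():
            p.requires_grad_(False)

        # d) SFT reference model: frozen, used to constrain policy updates
        self.sft_ref = sft_model.eval()
        for p in self.sft_ref.parameters():
            p.requires_grad_(False)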


Figure 4 Simplified version of PPO training architecture

(2) Loss function

  • a) actor loss (also known as policy loss; this is the loss of the model that will ultimately be used)
    In InstructGPT, the objective to be maximized is:
    $$\text{objective}(\phi) = \mathbb{E}_{(x,y)\sim D_{\pi_\phi^{RL}}}\Big[r_\theta(x,y) - \beta\,\log\frac{\pi_\phi^{RL}(y\mid x)}{\pi^{SFT}(y\mid x)}\Big] + \gamma\,\mathbb{E}_{x\sim D_{\text{pretrain}}}\big[\log \pi_\phi^{RL}(x)\big]$$
    where π_φ^RL is the actor and π^SFT is the trained SFT Model. The first and second terms are the core parts, and the third term is optional. Details are as follows:
    • First term: the reward score given by the reward model RM; this reward is to be maximized;
    • Second term: penalizes the RL policy for drifting far from the initial model in each training batch, which keeps the generated text reasonably coherent. If this penalty is removed, optimization may drive the model to output gibberish that fools the reward model into giving high rewards;
    • Third term: the pre-training gradient (optional), which is not part of traditional PPO. InstructGPT adds it so that RLHF does not degrade the large model's performance on public NLP evaluation tasks; the variant with this term added is named PPO-ptx.
  • b) critic loss (also known as value loss)
    Uses clipped_value_loss; a sketch is given below.
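
The clipped value loss follows the same idea as the clipped policy objective: the new value prediction may only move a bounded distance away from its old estimate. A minimal sketch in the style of PaLM-rlhf-pytorch (not necessarily the exact implementation used here):

import torch

def clipped_value_loss(values, rewards, old_values, clip):
    # constrain the new value estimate to stay within `clip` of the old estimate
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    loss_clipped = (value_clipped - rewards) ** 2
    loss_unclipped = (values - rewards) ** 2
    # elementwise maximum (pessimistic update); the caller averages the result
    return torch.max(loss_clipped, loss_unclipped)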

(3) Core code
a) training_step

    def training_step(self, batch, batch_idx, optimizer_idx):
        sequences, \
        prompt_masks, \
        masks, \
        old_action_probs, \
        old_log_probs, \
        rewards, \
        old_values = batch

        # PPO training
        action_masks = ~prompt_masks & masks

        action_logits, values = self.actor_critic(
            sequences,
            mask = action_masks
        )

        action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
        action_len = old_log_probs.shape[-1]

        action_probs = action_logits.softmax(dim = -1)
        action_log_probs = log_prob(action_probs, sequences)
        action_log_probs = action_log_probs[:, -action_len:]

        # calculate entropies, taking into account which part of the sequence is actually an action

        entropies = masked_entropy(action_probs, mask = action_masks)

        # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not

        kl_div_loss = 0.

        if self.args.kl_div_loss_weight > 0:
            kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

        # handle non-pooled values

        normalize_kwargs = dict()

        if old_values.ndim == 2:
            old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

            old_values = old_values[:, -action_len:]
            values = values[:, -action_len:]
            rewards = rearrange(rewards, 'b -> b 1')
            normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

        if values.ndim < rewards.ndim:
            values = rearrange(values, '... -> ... 1')

        # calculate clipped surrogate objective, classic PPO loss

        ratios = (action_log_probs - old_log_probs).exp()
        advantages = masked_normalize(rewards - old_values, **normalize_kwargs)

        if advantages.ndim == 1:
            advantages = rearrange(advantages, 'b -> b 1')

        surr1 = ratios * advantages
        surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
        policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

        # actor loss (also called policy loss; this is the loss for the model that will ultimately be used)
        if optimizer_idx == 0:
            actor_loss = policy_loss.mean() + kl_div_loss
            return actor_loss

        # critic loss (also called value loss)
        # update value network separate from policy network
        if optimizer_idx == 1:
            critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
            critic_loss = critic_loss.mean()
            return critic_loss
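
The optimizer_idx argument implies that the actor and critic are updated by two separate optimizers, PyTorch Lightning style. A hypothetical configure_optimizers consistent with that usage (the attribute names actor_critic.actor / actor_critic.critic and the learning-rate arguments are assumptions for illustration):

    def configure_optimizers(self):
        # hypothetical sketch: one optimizer per sub-network
        actor_optimizer = torch.optim.Adam(
            self.actor_critic.actor.parameters(), lr=self.args.actor_lr
        )
        critic_optimizer = torch.optim.Adam(
            self.actor_critic.critic.parameters(), lr=self.args.critic_lr
        )
        # returning two optimizers makes training_step receive optimizer_idx 0 and 1
        return actor_optimizer, critic_optimizer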

b) gen_experience_dataset

    def gen_experience_dataset(self):
        ''' Generate training data by interacting with the environment
        '''
        
        device = self.device

        time_cnt = 0
        for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
            for timestep in range(self.args.max_timesteps):
                time_cnt += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from rwkv as well as the action probs)
                # also calculate the reward using reward model and store
                # randomly pick one prompt
                rand_prompt_index = randrange(0, len(self.prompts))
                state = self.prompts[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.args.pad_value
                state = state[state_mask]

                # get predicted sequence
                # interact with the environment; in the returned values:
                #   actions is the response,
                #   sequence is prompt + response,
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = self.args.ctx_len,
                    return_values = True
                )
                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

                action_prob = action_logits.softmax(dim = -1)

                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]

                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)

                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length

                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)

                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )

                self.sequence_batch.append(sequence)
                self.prompt_mask_batch.append(prompt_mask)
                self.mask_batch.append(mask)
                self.action_prob_batch.append(action_prob)
                self.action_log_prob_batch.append(action_log_prob)
                self.reward_batch.append(reward)
                self.value_batch.append(value)

                if time_cnt % self.args.update_timesteps == 0:
                    train_data = zip(
                        self.sequence_batch, self.prompt_mask_batch, self.mask_batch, 
                        self.action_prob_batch, self.action_log_prob_batch, self.reward_batch, 
                        self.value_batch
                    )

                    for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
                        yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value

                    self.sequence_batch.clear()
                    self.prompt_mask_batch.clear()
                    self.mask_batch.clear()
                    self.action_prob_batch.clear()
                    self.action_log_prob_batch.clear()
                    self.reward_batch.clear()
                    self.value_batch.clear()

3 Summary

RLHF can continuously learn from and optimize against user feedback, thereby improving the quality and effectiveness of the dialogue. However, limited by compute resources, we have so far only debugged the RLHF training pipeline and run it end to end; we have not yet trained the model on a real dataset. If there are any mistakes, corrections are welcome. Thank you!

4 Reference

[1] InstructGPT
[2] The "hero" behind ChatGPT - RLHF technical details
[3] ColossalAI
[4] PaLM-rlhf-pytorch
[5] Proximal Policy Optimization with PyTorch
[6] How ChatGPT Works Part 2: The Reward Model
