Team Blog: CSDN AI Group
Related Reading
- Introduction to ChatGPT
- A Brief Discussion on Large Language Models 1
- 10 Must-See Papers About ChatGPT
- From ELMo to ChatGPT: Counting the must-see large models of NLP in the past 5 years
1 Introduction
ChatGPT's popularity keeps growing. It can handle complex language tasks, which frees up human effort, improves work efficiency, and reduces costs. Its advanced technology and wide range of applications make it one of the most talked-about AI technologies today, and enterprises, academic institutions, and technology enthusiasts alike are watching its application prospects closely.
Against this background, the CSDN AI team set out to build a simple reproduction of ChatGPT. According to the official ChatGPT blog, ChatGPT is trained in essentially the same way as InstructGPT (as shown in Figure 1), only with different data sets. We therefore follow InstructGPT for the training method and use RWKV as the base model. The training breaks down into the following four stages:
- (1) Language Model Pre-training;
- (2) Supervised Fine-Tuning (SFT);
- (3) Reward Model Training (Reward Modeling, RM);
- (4) Reinforcement learning with the Proximal Policy Optimization (PPO) algorithm.
Stages (1) and (2), pre-training and SFT, were completed by @zxm2015; see the article Exploring the Large Language Model 1. This article covers stages (3) and (4), that is, Reinforcement Learning from Human Feedback (RLHF).
2 Reinforcement Learning from Human Feedback (RLHF)
Reinforcement Learning from Human Feedback (RLHF) is the technique ChatGPT uses to improve the quality of its answers. It is a reinforcement-learning approach that optimizes ChatGPT's responses by incorporating human feedback.
In RLHF, ChatGPT learns to improve its responses by interacting with human users. When ChatGPT generates an answer, the answer is shown to a user, who is asked to rate it, for example as "very good", "good", "average", or "poor". This feedback serves as a reward or punishment signal for updating the model so that it better meets user needs.
RLHF can be divided into two parts. The first is the reward model, which is where the human feedback enters. The second is the reinforcement learning stage, which uses the proximal policy optimization algorithm to optimize the model against the reward model's feedback and finally yields a language model that matches human preferences. Both parts are described in detail below.
2.1 Reward Model (RM)
Before RLHF, the language model has already gone through SFT (hereafter the SFT Model). The reward model's job is to score the SFT Model's replies: the higher the score, the better the answer. Once trained, the reward model is used in the subsequent PPO stage, where it is one component of PPO and provides the reward signal for training.
(1) Input and output of the model
The input of the model is the pair <Prompt, Response> of the user question (Prompt) and the SFT Model reply (Response), and the output is a reward score, as shown in the following figure:
(2) The construction of the data set
This stage trains the RM on manually labeled data, and this is where human feedback comes in. Questions are randomly sampled from the Prompts dataset, and K different responses are generated for each question. Human annotators then rank these responses, weighing criteria such as relevance, informativeness, and harmful content.
Given the input and output of the reward model described above, each <Prompt, Response> would ideally be scored by hand. In practice, however, assigning scores to several answers is difficult, and because scores are continuous, doing so slows down labeling. Moreover, what we actually care about is which of several options is better and which is worse. It is therefore enough for annotators to rank the options. The data set is then built from the ranked answers, and a suitable loss function is chosen.
People generally rank fastest and most accurately when there are 4 to 9 options (that is, K ∈ {4, 5, 6, 7, 8, 9}). Here we set K = 4, which yields C(4, 2) = 6 training samples per Prompt.
Specifically, suppose we select a question x and use the SFT Model to generate 4 answers {y1, y2, y3, y4}, which human annotators rank as y4 > y3 > y1 > y2. The resulting training samples are shown below; in each pair, the <Prompt, Response> on the left scores higher than the one on the right:
(<x, y4>, <x, y3>)
(<x, y4>, <x, y1>)
(<x, y4>, <x, y2>)
(<x, y3>, <x, y1>)
(<x, y3>, <x, y2>)
(<x, y1>, <x, y2>)
(3) Loss function
The data set constructed above gives us no continuous score to regress on, only preferred/rejected sample pairs, so the reward model is trained by minimizing the following pairwise ranking loss:

$$\text{loss}(\theta) = -\frac{1}{\binom{K}{2}} \, \mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( r_\theta(x, y_w) - r_\theta(x, y_l) \right) \right]$$

Here r_θ(x, y) is the score the RM assigns to the input <x, y>, θ denotes the RM parameters, σ is the sigmoid function, D is the data set of comparison pairs, and y_w and y_l are two different answers generated by the SFT Model for input x, with y_w > y_l in the manual labeling.
# loss function: mean of -log(sigmoid(r_preferred - r_rejected)) over the batch
import torch

def loss_function(prefer_reward, alter_reward):
    return -torch.mean(torch.log(torch.sigmoid(prefer_reward - alter_reward)))
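To get a feel for how this loss behaves, here is a small sketch that evaluates it on hand-picked reward values. It uses `torch.nn.functional.logsigmoid`, which computes the same quantity as `log(sigmoid(...))` and tends to be more robust for large negative margins; the function name `pairwise_ranking_loss` and the sample values are illustrative assumptions.

```python
import math
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(prefer_reward, alter_reward):
    """-mean(log(sigmoid(prefer - alter))), computed via logsigmoid."""
    return -F.logsigmoid(prefer_reward - alter_reward).mean()

# The RM cannot tell the two answers apart: loss equals log(2)
equal = pairwise_ranking_loss(torch.tensor([0.5]), torch.tensor([0.5]))
# The RM scores the preferred answer much higher: loss is close to 0
correct = pairwise_ranking_loss(torch.tensor([3.0]), torch.tensor([-3.0]))
# The RM scores the rejected answer much higher: loss is large
wrong = pairwise_ranking_loss(torch.tensor([-3.0]), torch.tensor([3.0]))

print(equal.item(), math.log(2))
print(correct.item(), wrong.item())
```

The loss only depends on the difference between the two scores, so the RM is free to choose any absolute scale as long as preferred answers end up above rejected ones.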
(4) Core code
Compared with the SFT Model, the network structure of the RM needs little change. After <Prompt, Response> is fed in, the embedding of the last token is taken directly and passed through a linear layer to compute the reward score.
a) Linear layer:
# reward score computation
self.pred_reward = nn.Linear(dim, 1, bias=False)
b) forward function
# assumes: import torch; from einops import rearrange, repeat;
# and a helper exists(v) that returns `v is not None`
def forward(
    self,
    x,
    mask = None,
    prompt_mask = None,
    prompt_lengths = None
):
    # prompt_mask and prompt_lengths are mutually exclusive: pass at most one
    assert not (exists(prompt_mask) and exists(prompt_lengths))

    # derive prompt mask from prompt lengths
    if exists(prompt_lengths):
        batch, seq_len = x.shape
        arange = torch.arange(seq_len, device=x.device)
        prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')

    # reward model should have an understanding of which section is prompt, and which section is response
    # Depending on whether each token's entry in prompt_mask is True or False,
    # take the value from prompt_embed or response_embed:
    # True selects prompt_embed, otherwise response_embed
    # NOTE: indexing the stacked tensor as done below needs integer indices,
    # so a boolean prompt_mask would have to be converted first
    prompt_response_mask_embed = torch.stack([
        self.prompt_embed,
        self.response_embed,
        self.padding_embed
    ]).to(prompt_mask.device)

    extra_embed = None
    if exists(prompt_mask):
        extra_embed = prompt_response_mask_embed[prompt_mask]

    # get the embedding of the last token
    last_token_embeds = self.rwkv(
        x,
        extra_embed=extra_embed,
        rm_train=True
    )[:, -1, :]

    # compute the reward
    reward = self.pred_reward(last_token_embeds)
    reward = reward.squeeze(-1)
    return reward
c) train_forward function
def train_forward(self, x_p, x_a, m_p, m_a):
    # The forward pass runs the model twice, so gradients must be
    # disabled for one of the two passes. Otherwise the gradient is
    # computed twice, which raises an error under DeepSpeed:
    # "Gradient computed twice for this partition."
    with torch.enable_grad():
        prefer_reward = self.forward(x_p, prompt_mask=m_p)
    with torch.no_grad():
        alter_reward = self.forward(x_a, prompt_mask=m_a)
    return prefer_reward, alter_reward
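The pieces above (backbone, last-token embedding, linear head, pairwise loss) can be tried end to end on a toy model. The sketch below is not the team's RWKV model: `TinyRewardModel`, the GRU backbone, the vocabulary size, and the random token ids are all stand-ins chosen to keep the example self-contained.

```python
import torch
import torch.nn as nn

class TinyRewardModel(nn.Module):
    """Toy reward model: backbone -> last-token embedding -> scalar reward."""

    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        # Stand-in backbone (assumption): embedding + GRU instead of RWKV.
        self.embed = nn.Embedding(vocab_size, dim)
        self.backbone = nn.GRU(dim, dim, batch_first=True)
        # Same reward head as in the article: one linear layer, no bias.
        self.pred_reward = nn.Linear(dim, 1, bias=False)

    def forward(self, x):
        hidden, _ = self.backbone(self.embed(x))      # (batch, seq_len, dim)
        last_token_embeds = hidden[:, -1, :]          # (batch, dim)
        return self.pred_reward(last_token_embeds).squeeze(-1)  # (batch,)

torch.manual_seed(0)
model = TinyRewardModel()
preferred = torch.randint(0, 100, (4, 12))  # 4 tokenized <Prompt, Response> inputs
rejected = torch.randint(0, 100, (4, 12))   # the lower-ranked counterparts

loss = -torch.log(torch.sigmoid(model(preferred) - model(rejected))).mean()
loss.backward()
print(model(preferred).shape, loss.item())
```

Unlike `train_forward` above, this toy version lets gradients flow through both passes, which plain PyTorch handles without the DeepSpeed partitioning error.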
2.2 Proximal Policy Optimization Algorithm (PPO)
Proximal Policy Optimization (PPO) is a deep reinforcement learning algorithm whose goal is to learn a policy that maximizes long-term cumulative returns.
(1) The PPO algorithm includes the following main parts:
- a) Policy network: learns and outputs the probability distribution over actions in a given state. It is usually a neural network that is updated based on feedback from the environment. It corresponds to the Actor in Figure 3, is initialized from the SFT Model, and is trained during PPO.
- b) Value network: predicts the expected return for a given state. It is also usually a neural network, and its output is used to compute the advantage function, which in turn helps update the policy network. It corresponds to the Critic in Figure 3, is initialized from the RM, and is trained during PPO.
- c) Reward model: corresponds to the Reward Model in Figure 3 and is the model trained in Section 2.1. It is not trained during PPO and only provides the reward signal for PPO training.
- d) SFT Model: corresponds to the Supervised Fine-Tune Model in Figure 3 and is used when updating the policy network so that it can produce better policies. By limiting the magnitude of each update, it ensures that the updated policy does not deviate too much from the original policy. This part may or may not take part in training; when it does, PPO is called PPO-ptx.
- e) Experience sampling: collects experience data from interaction with the environment, which is used to update the policy network and the value network. In the PPO algorithm, experience sampling usually follows a strategy based on action-value estimation. It corresponds to the Prompts -> Actor -> Response process at the top of Figure 3.
(2) Loss function
- a) actor loss (also known as policy loss; this is the loss of the model that is ultimately used)

$$\text{objective}(\phi) = \mathbb{E}_{(x, y) \sim D_{\pi_\phi^{RL}}} \left[ r_\theta(x, y) - \beta \log \frac{\pi_\phi^{RL}(y \mid x)}{\pi^{SFT}(y \mid x)} \right] + \gamma \, \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log \pi_\phi^{RL}(x) \right]$$

where π^RL is the actor, π^SFT is the trained SFT Model, and β and γ are weighting coefficients. The first and second terms are the core of the objective, and the third term is optional. This objective is to be maximized. In detail:
- First term: the reward score given by the reward model RM, which should be maximized;
- Second term: penalizes the RL policy for drifting far from the initial model in each training batch, which keeps the model's output reasonably coherent. If this penalty is removed, the model may generate garbled text during optimization to trick the reward model into giving high rewards;
- Third term: the pre-training gradient (optional), which traditional PPO generally does not include. InstructGPT adds it to keep RLHF from degrading the large model on public NLP evaluation tasks. With this term added, the method is called PPO-ptx.
- b) critic loss (also known as value loss): uses clipped_value_loss.
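The first two terms of the actor objective can be illustrated with a small sketch that computes the penalized reward for each sampled response. The function name `kl_penalized_reward`, the tensor layout, and the value of `beta` are assumptions made for this example; they are not taken from the team's code.

```python
import torch

def kl_penalized_reward(reward, logprob_rl, logprob_sft, beta=0.1):
    """r(x, y) - beta * log(pi_RL(y|x) / pi_SFT(y|x)) for each sampled response.

    reward:      (batch,) scores from the reward model
    logprob_rl:  (batch, resp_len) per-token log-probs under the actor
    logprob_sft: (batch, resp_len) per-token log-probs under the SFT model
    """
    # Summing per-token log-probs gives the sequence log-probability,
    # so the difference is log(pi_RL(y|x) / pi_SFT(y|x)).
    log_ratio = (logprob_rl - logprob_sft).sum(dim=-1)
    return reward - beta * log_ratio

reward = torch.tensor([1.0, 1.0])
logprob_sft = torch.log(torch.tensor([[0.5, 0.5], [0.5, 0.5]]))
# Response 0: actor agrees with the SFT model. Response 1: actor has drifted.
logprob_rl = torch.log(torch.tensor([[0.5, 0.5], [0.9, 0.9]]))

penalized = kl_penalized_reward(reward, logprob_rl, logprob_sft, beta=0.1)
print(penalized)
```

Both responses receive the same raw reward, but the one on which the actor has moved away from the SFT Model ends up with a lower penalized reward, which is the effect the second term is meant to have.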
(3) Core code
a) training_step
def training_step(self, batch, batch_idx, optimizer_idx):
    sequences, \
    prompt_masks, \
    masks, \
    old_action_probs, \
    old_log_probs, \
    rewards, \
    old_values = batch

    # PPO training
    action_masks = ~prompt_masks & masks
    action_logits, values = self.actor_critic(
        sequences,
        mask = action_masks
    )
    action_logits = shift(action_logits, shift=1, dim=-2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
    action_len = old_log_probs.shape[-1]
    action_probs = action_logits.softmax(dim = -1)
    action_log_probs = log_prob(action_probs, sequences)
    action_log_probs = action_log_probs[:, -action_len:]

    # calculate entropies, taking into account which part of the sequence is actually an action
    entropies = masked_entropy(action_probs, mask = action_masks)

    # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
    kl_div_loss = 0.
    if self.args.kl_div_loss_weight > 0:
        kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight

    # handle non-pooled values
    normalize_kwargs = dict()
    if old_values.ndim == 2:
        old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
        old_values = old_values[:, -action_len:]
        values = values[:, -action_len:]
        rewards = rearrange(rewards, 'b -> b 1')
        normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])

    if values.ndim < rewards.ndim:
        values = rearrange(values, '... -> ... 1')

    # calculate clipped surrogate objective, classic PPO loss
    ratios = (action_log_probs - old_log_probs).exp()
    advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
    if advantages.ndim == 1:
        advantages = rearrange(advantages, 'b -> b 1')
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages
    policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies

    # actor loss (also known as policy loss; the loss of the model that is ultimately used)
    if optimizer_idx == 0:
        actor_loss = policy_loss.mean() + kl_div_loss
        return actor_loss

    # critic loss (also known as value loss)
    # update value network separate from policy network
    if optimizer_idx == 1:
        critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip)
        critic_loss = critic_loss.mean()
        return critic_loss
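The two clipping steps in `training_step` are easier to follow on single numbers. The sketch below isolates them; it leaves out masking, the entropy bonus, and advantage normalization. The article does not show the body of `clipped_value_loss`, so the version here is one common form with the same argument order as the call above, and the default clip values are arbitrary.

```python
import torch

def clipped_policy_loss(log_probs, old_log_probs, advantages, eps_clip=0.2):
    """Negative PPO clipped surrogate, averaged over all tokens."""
    ratios = (log_probs - old_log_probs).exp()
    surr1 = ratios * advantages
    surr2 = ratios.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    return -torch.min(surr1, surr2).mean()

def clipped_value_loss(values, rewards, old_values, clip=0.4):
    """Squared error where the new value may move at most `clip` from the old one."""
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    loss_clipped = (value_clipped - rewards) ** 2
    loss_unclipped = (values - rewards) ** 2
    # Taking the larger of the two errors keeps the update pessimistic.
    return torch.max(loss_clipped, loss_unclipped).mean()

# Probability ratio new/old = 1.5, which lies outside [0.8, 1.2]
new_lp = torch.log(torch.tensor([[1.5]]))
old_lp = torch.zeros(1, 1)

gain = clipped_policy_loss(new_lp, old_lp, torch.tensor([[1.0]]))      # ratio clipped to 1.2
penalty = clipped_policy_loss(new_lp, old_lp, torch.tensor([[-1.0]]))  # unclipped 1.5 is kept

value_loss = clipped_value_loss(
    values=torch.tensor([2.0]), rewards=torch.tensor([3.0]), old_values=torch.tensor([1.0])
)
print(gain.item(), penalty.item(), value_loss.item())
```

With a positive advantage the gain from raising the action's probability is capped at a ratio of 1.2, while with a negative advantage the full ratio of 1.5 is charged, so the policy has no incentive to move far in a single update.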
b) gen_experience_dataset
def gen_experience_dataset(self):
    ''' Generate training data by interacting with the environment
    '''
    device = self.device
    time_cnt = 0
    for eps in tqdm(range(self.args.num_episodes), desc = 'episodes'):
        for timestep in range(self.args.max_timesteps):
            time_cnt += 1

            # select a bunch of random states (prompts)
            # and get the action (sampled sequence from rwkv as well as the action probs)
            # also calculate the reward using reward model and store
            # pick one prompt at random
            rand_prompt_index = randrange(0, len(self.prompts))
            state = self.prompts[rand_prompt_index]

            # remove padding from state
            state_mask = state != self.args.pad_value
            state = state[state_mask]

            # get predicted sequence
            # interact with the environment; in the returned values:
            # actions is the response,
            # sequence is prompt + response
            (
                actions,
                sequence,
                mask,
                prompt_mask,
                action_logits,
                value
            ) = self.actor_critic.generate(
                rearrange(state, 'n -> 1 n'),
                max_seq_len = self.args.ctx_len,
                return_values = True
            )
            action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
            action_prob = action_logits.softmax(dim = -1)
            action_len = actions.shape[-1]
            action_log_prob = log_prob(action_prob, sequence)
            action_log_prob = action_log_prob[:, -action_len:]
            actions = rearrange(actions, '1 ... -> ...')

            # get reward as given by supervised trained reward model
            sequence = torch.cat((state, actions), dim = 0)
            prompt_length = len(state)
            prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
            sequence = rearrange(sequence, 'n -> 1 n')
            prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
            mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
            reward = self.reward_model(
                sequence,
                prompt_mask = prompt_mask,
                mask = mask,
                sample = True
            )

            # store this experience
            self.sequence_batch.append(sequence)
            self.prompt_mask_batch.append(prompt_mask)
            self.mask_batch.append(mask)
            self.action_prob_batch.append(action_prob)
            self.action_log_prob_batch.append(action_log_prob)
            self.reward_batch.append(reward)
            self.value_batch.append(value)

            # every update_timesteps steps, hand the buffered experience to training and reset the buffers
            if time_cnt % self.args.update_timesteps == 0:
                train_data = zip(
                    self.sequence_batch, self.prompt_mask_batch, self.mask_batch,
                    self.action_prob_batch, self.action_log_prob_batch, self.reward_batch,
                    self.value_batch
                )
                for _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value in train_data:
                    yield _sequence, _prompt_mask, _mask, _action_prob, _action_log_prob, _reward, _value
                self.sequence_batch.clear()
                self.prompt_mask_batch.clear()
                self.mask_batch.clear()
                self.action_prob_batch.clear()
                self.action_log_prob_batch.clear()
                self.reward_batch.clear()
                self.value_batch.clear()
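Both code listings above call the helpers `shift` and `log_prob` without showing them. The sketch below gives plausible stand-ins that match the call signatures used here (`shift(t, shift=1, dim=-2)` and `log_prob(prob, indices)`). Their bodies are assumptions written for illustration and may differ from the helpers the team actually uses.

```python
import torch

def shift(t, value=0.0, shift=1, dim=-1):
    """Shift `t` right by `shift` positions along `dim`, padding with `value`."""
    dim = dim % t.ndim
    pad_shape = list(t.shape)
    pad_shape[dim] = shift
    pad = t.new_full(pad_shape, value)
    kept = t.narrow(dim, 0, t.shape[dim] - shift)  # drop the last `shift` entries
    return torch.cat((pad, kept), dim=dim)

def log_prob(prob, indices, eps=1e-20):
    """Log-probability that `prob` assigns to each token id in `indices`."""
    picked = prob.gather(-1, indices.unsqueeze(-1)).squeeze(-1)
    return picked.clamp(min=eps).log()

# (batch=1, seq_len=3, vocab=2) toy logits: row t is the prediction made at position t
logits = torch.tensor([[[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]])
sequence = torch.tensor([[0, 0, 1]])

# After shifting, row t holds the prediction made at position t - 1,
# i.e. the distribution from which token t was sampled.
shifted = shift(logits, shift=1, dim=-2)
token_log_probs = log_prob(shifted.softmax(dim=-1), sequence)
print(shifted)
print(token_log_probs)
```

The shift is what lines each generated token up with the logits that produced it, so that slicing the last `action_len` positions yields the log-probabilities of the response tokens only.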
3 Summary
RLHF can keep learning from user feedback and optimizing dialogue, improving its quality and effectiveness. Because of limited computing resources, however, we have only debugged the RLHF training pipeline and run it end to end; we have not yet trained the model on a real data set. If you find any mistakes, please point them out. Thank you!
4 References
[1] InstructGPT
[2] The "hero" behind ChatGPT - RLHF technical details
[3] ColossalAI
[4] PaLM-rlhf-pytorch
[5] Proximal Policy Optimization with PyTorch
[6] How ChatGPT Works Part 2: The Reward Model