RLHF における PPO アルゴリズムの原理と実装
ChatGPT は、InstructGPT に基づいたマルチラウンド ダイアログ生成大規模モデルです。ChatGPT に主に関与するテクノロジーには次のものがあります。
- 命令チューニング。
- 因果言語モデリング。
- 人間の調整
ブロガーは、以前の記事で命令の微調整の原則と関連するプロンプト テクノロジを紹介し(参照:プロンプト チューニング—新しい微調整パラダイムの詳細な解釈)、GPT などの因果言語モデルに関する関連紹介: [事前トレーニング済み言語モデル] GPT: 生成的事前トレーニングによる言語理解を向上させる。したがって、大規模なモデルでは、基本的な生成モデルをトレーニングする方法に加えて、大規模な生成モデルを人間の価値観に合わせて作成する方法にも焦点を当てる必要があります。
前回の記事InstructGPT の原理説明と ChatGPT オープンソース プロジェクトでは、 ChatGPT と最近オープンソースになった ChatGPT モデルがどのようにアライメントを実現するかを紹介しましたが、ここでは InstructGPT におけるヒューマン アライメントのコア アルゴリズムである RLHF (ヒューマン アライメント強化学習) についても詳しく紹介します。 PPO アルゴリズム。
この記事では主に次の 2 つの参考資料を参照します。
[1]強化学習の簡単な概要: MDP、DP MC TC および Q 学習の一般的な理解、ポリシー勾配、PPO
[2] DeepSpeed に基づく ChatGPT のトレーニング
1. RLHF PPO アルゴリズムの原理
PPO アルゴリズムは、特定の Actor-Critic アルゴリズムの実装です。たとえば、対話ロボットでは、入力プロンプトは状態、出力応答はアクションです。望ましい戦略は、最大の報酬を得るためにプロンプトからアクションを生成する方法です。 、それは人間の好みに合わせるということです。
PPO アルゴリズムには、次の 2 つの戦略が含まれます。
- 近接ポリシー最適化ペナルティ (PPO ペナルティ)。
- プロキシマル戦略は、PPO クリップのクリッピングを最適化します。
重要度のサンプリング
Actor-Critic トレーニング中、ポリシー関数パラメーターが最適化された後は、ポリシー サンプリングの前のラウンドのアクション状態シーケンスを使用できないため、ポリシー関数の更新ごとにサンプリングが繰り返される問題を回避するために重要度サンプリングが必要です。 。データを分布 p ではサンプリングできず、別の分布 q からのみサンプリングできる場合 (q は任意の分布)。
重要度サンプリングの原則:
KL 発散制約:
重要度サンプリングでは、p 分布と q 分布をあまり遠くまでチェックすることはできないため、制約を課すために KL 発散が必要です。
アドバンテージ:
Actor-Critic アルゴリズムでは、利点を定義する必要があります。最も簡単な方法は、報酬ベースラインを定義することです。これは、 としても定義できます。ここで、V π ( s ) V_{\pi}(s)Vp( s )は現在の状態ssすべてのアクションがsおよび Q π ( s , a )で実行された後に得られる報酬の期待値Q_{\pi}(s, a)Qp( s 、a )現在の状態を表しますsssの下にアクションaa受け取った報酬。したがって、A π ( s , a ) > 0 A_{\pi}(s, a)>0あp( s 、)_>0、現在のアクションを意味しますaaaによって得られる報酬は全体の期待よりも大きいため、このアクションの確率は最大化される必要があります。
一般に、アドバンテージは、報酬のみが絶対値として使用される場合に発生する高分散の問題を回避し、正の値と負の値を通じてどのアクションが正のフィードバックを取得できるかを戦略に伝えることを目的としています。
利点+重要度のサンプリング:
アドバンテージは重要度サンプリングにおけるf ( x ) f(x)として考えることができます。f ( x )。最適化プロセス中にパラメーターが変化するため、重要度のサンプリングが必要となるため、最適化の目標は次のようになります。
J θ ' = E st , at 〜 π θ ' [ p θ ( at ∣ st ) p θ ' ( at , st ) A θ ' ( st , at ) ] J^{\theta'}=\mathbb{E} _{s_t, a_t}\sim\pi_{\theta'}\bigg[\frac{p_{\theta}(a_t|s_t)}{p_{\theta'}(a_t, s_t)}A^{\theta '}(s_t, a_t)\bigg]J私』=Esた、_た〜円周率私「[p私「( _た、sた)p私( _た∣ sた)あ私' (sた、あるた) ]
近接ポリシー最適化ペナルティ (PPO ペナルティ)
PPO アルゴリズムの近接ポリシー最適化ペナルティの原理を次の図に示します。
プロキシマル戦略 最適クリッピング PPO-clip
最適化目標は次のように変更されます。
式の理解:
KL 発散と比較して、KL 発散は 2 つの分布の出力ロジットに制約されますが、クリップ法は確率比に直接制約を適用します。
2. RLHF PPOアルゴリズムの実現
(1) まずRLHFクラスとPPOTrainerを初期化します。
rlhf_engine = DeepSpeedRLHFEngine(
actor_model_name_or_path=args.actor_model_name_or_path,
critic_model_name_or_path=args.critic_model_name_or_path,
tokenizer=tokenizer,
num_total_iters=num_total_iters,
args=args)
ppo_trainer = DeepSpeedPPOTrainer
trainer = ppo_trainer(rlhf_engine, args)
初期化中に、Actor、SFT、Critic、Reward を含む 4 つのモデルをロードします。
コード内の self.ref は実際には SFT モデルです
class DeepSpeedRLHFEngine():
def __init__(self, actor_model_name_or_path, critic_model_name_or_path,
tokenizer, args, num_total_iters):
self.args = args
self.num_total_iters = num_total_iters
self.tokenizer = tokenizer
# 用训练好的SFT模型初始化Actor模型
self.actor = self._init_actor(
actor_model_name_or_path=actor_model_name_or_path)
# 用训练好的SFT模型初始化SFT模型
self.ref = self._init_ref(
actor_model_name_or_path=actor_model_name_or_path)
self.actor_ema = None
if self.args.enable_ema:
self.actor_ema = self._init_ema(
actor_model_name_or_path=actor_model_name_or_path)
# 用训练好的RW初始化Critic模型
self.critic = self._init_critic(
critic_model_name_or_path=critic_model_name_or_path)
# 用训练好的RW初始化reward模型
self.reward = self._init_reward(
critic_model_name_or_path=critic_model_name_or_path)
if self.args.critic_gradient_checkpointing:
self.critic.gradient_checkpointing_enable()
(2) RLHFの学習データをロードする
prompt_train_dataloader, num_total_iters = create_datasets(
args=args, tokenizer=tokenizer, train_phase=3)
(3) RLHF の全体的なトレーニング プロセスは次のとおりです。
具体的なプロセスはコードのコメントに記載されています。一般に、主なプロセスは次のとおりです。
- 各エポックを横断し、各エポックの各バッチを横断します。
- バッチごとに、まず一連の経験的データをサンプリングします。
- 経験的データに基づいてアクター モデルと批評家モデルをトレーニングする
# 训练的总Epoch数
for epoch in range(args.num_train_epochs):
# 遍历每一个Batch
for step, (batch_prompt) in enumerate(prompt_train_dataloader):
batch_prompt = to_device(batch_prompt, device)
prompts = batch_prompt['prompt'] # prompt
length = prompts.size(-1)
# 进行采样,并加入到经验池,详见(3.1)
out = trainer.generate_experience(prompts)
exp_dataset = exp_mini_dataset.add(out)
if exp_dataset is not None:
inner_iter = 0
critic_loss, actor_loss = 0, 0
average_reward = 0
if args.actor_gradient_checkpointing:
rlhf_engine.actor.gradient_checkpointing_enable()
# 从经验池中进行学习Epoch轮
for ppo_ep in range(args.ppo_epochs):
for i, (exp_data) in enumerate(exp_dataset):
# 得到actor和critic loss,详见(3.2)
actor_loss, critic_loss = trainer.train_rlhf(exp_data)
critic_loss += actor_loss.item()
actor_loss += critic_loss.item()
average_reward += exp_data["rewards"].mean()
inner_iter += 1
if args.enable_ema:
moving_average(rlhf_engine.actor,
rlhf_engine.actor_ema,
zero_stage=args.actor_zero_stage)
# 每一轮结束后打乱经验池
random.shuffle(exp_dataset)
average_reward = get_all_reduce_mean(average_reward).item()
if args.actor_gradient_checkpointing:
rlhf_engine.actor.gradient_checkpointing_disable()
このトレーニング プロセスには主に 2 つの主要なステップが含まれます。
- エクスペリエンスデータのサンプリング。
- サンプリングされたデータに基づいてアクター モデルと批評家モデルをトレーニングします。
これら 2 つのコア ステップを詳細に分析しましょう。これら 2 つのコア ステップを理解すると、RLHF PPO アルゴリズムをほぼ理解できるようになります。
体験サンプリング
写真はここから。
実装の詳細については、コードとコメントを参照してください。
def generate_experience(self, prompts):
self.eval() # 开启eval模式
# 输入instruct prompt,由Actor生成seq,上图中红色步骤(1),seq由instruct和response组成
seq = self._generate_sequence(prompts)
self.train() # 恢复训练模型
pad_token_id = self.tokenizer.pad_token_id
attention_mask = seq.not_equal(pad_token_id).long()
with torch.no_grad():
# 将seq喂入actor中得到action_logits,上图中棕色步骤(2)
output = self.actor_model(seq, attention_mask=attention_mask)
# 将seq喂入SFT中得到sft_logits,上图中黑色步骤(5)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
# 将seq喂入reward模型中打分,得到r(x, y),上图绿色步骤(4)
reward_score = self.reward_model.forward_value(
seq, attention_mask,
prompt_length=self.prompt_length)['chosen_end_scores'].detach(
)
# 将seq喂入critic,获得critic的value,上图蓝色步骤(3)
values = self.critic_model.forward_value(
seq, attention_mask, return_value_only=True).detach()[:, :-1]
logits = output.logits
logits_ref = output_ref.logits
# 获得经验数据
return {
'prompts': prompts,
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:, 1:]),
'value': values,
'rewards': reward_score,
'input_ids': seq,
"attention_mask": attention_mask
}
アドバンテージを取得し、アクターとクリティックのパラメーターを更新します
def train_rlhf(self, inputs):
# 当前RLHF轮次最初采样的经验池中采样一批数据
prompts = inputs['prompts'] # instruct prompt
log_probs = inputs['logprobs'] # actor模型生成response对应的action_logist
ref_log_probs = inputs['ref_logprobs'] # SFT模型生成response对应的sft_logits
reward_score = inputs['rewards'] # reward模型预测的奖励r(x, y)
values = inputs['value'] # critic模型预测的奖励
attention_mask = inputs['attention_mask']
seq = inputs['input_ids']
start = prompts.size()[-1] - 1
action_mask = attention_mask[:, 1:]
### 根据经验数据,接下来计算相应的reward和advantage
old_values = values # RLHF训练之前的critic给的value值
with torch.no_grad():
# 根据SFT的sft_logits和Actor的action_logist,计算KL散度;
# 并根据KL散度与reward模型预测的奖励r(x, y),获得最终奖励
# 上图中红色步骤(1)
old_rewards = self.compute_rewards(prompts, log_probs,
ref_log_probs, reward_score,
action_mask)
"""
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
action_mask):
# self.kl_ctl = 0.02
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
rewards = kl_divergence_estimate
start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
self.clip_reward_value)
batch_size = log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j]
"""
# 由critic或的的value与前面根据KL散度和r(x, y)得到的reward,从而计算得到advantage
# 上图蓝色步骤(2)
advantages, returns = self.get_advantages_and_returns(
old_values, old_rewards, start)
"""
def get_advantages_and_returns(self, values, rewards, start):
# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values[:, start:]
return advantages.detach(), returns
"""
### 根据经验数据以及得到的advatage,下面开始获得一系列的loss
batch = {
'input_ids': seq, "attention_mask": attention_mask}
# 将这一批经验数据的seq(instruct prompt+response)再一次喂入actor得到logits
# 因为现在是在更新actor和critic,而经验数据所采用的actor和critic早已经是之前的了,所以
# 现在正在更新的actor和critic与当时进行经验采样时的actor、critic的参数已经有差异了;
# 所以需要重新获得当前最新的actor输出的logits
# 上图中棕色步骤(3)
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :],
inputs['input_ids'][:, 1:])
# 根据新的actor logits以及经验数据中的logits,以及advantage,计算actor loss
# 上图中绿色步骤(4)
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
log_probs[:, start:], advantages,
action_mask[:, start:])
"""
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
## policy gradient loss
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
"""
# 更新actor模型参数
self.actor_model.backward(actor_loss)
self.actor_model.step()
# 经验数据中的seq(instruct prompt+response)再一次喂入critic得到value
# 同理,由于当前的critic和当初进行经验数据采样时的critic相差很远;所以需要重新获得value
# 上图中黑色步骤(5)
value = self.critic_model.forward_value(**batch,
return_value_only=True,
use_cache=False)[:, :-1]
# 根据最新的critic的value,经验数据的old_value,以及advatage,计算得到critic loss
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,
start:],
returns, action_mask[:, start:])
"""
def critic_loss_fn(self, values, old_values, returns, mask):
## value loss
values_clipped = torch.clamp(
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
return vf_loss
"""
# 更新critic参数
self.critic_model.backward(critic_loss)
self.critic_model.step()
return actor_loss, critic_loss
ブロガーは今後も大規模モデルに関するテクノロジーを更新し続けます。関連記事については、以下を参照してください。
【1】大規模モデルのトレーニングと推論の最適化技術について詳しく説明します
【2】Prompt-Tuning - 新しい微調整パラダイムの詳細な解釈
【3】InstructGPT の原理説明と ChatGPT オープンソース プロジェクト
【4】DeepSpeed に基づく ChatGPT トレーニング
【5】【HuggingFaceは簡単に始められます】Wikipediaに基づいた知識強化の事前トレーニング
【6】PytorchシングルマシンマルチカードGPUの実装(原理概要、基本フレームワーク、一般的なエラーレポート)