フレームワークの使用法と RLHF 実践におけるいくつかの落とし穴 (TRL、LMFlow)

From: ハグフェイス

NLP グループに入る —> NLP 交換グループに参加する

1 はじめに

以前、いくつかの一般的な RLHF フレームワークの体験をまとめた記事を見たことがありますが、Hugging Face が管理する TRL ライブラリに関する関連記事は見たことがないようです。使用中に遭遇した落とし穴を共有する記事を書いて、完全なプロセス フレームワーク LMFlow についても紹介します。

beff4b0835ce6ed828a407874afc5455.png

LMFlow フレームワークの概略図。

主に具体的な例を使用して、2 つのフレームワークの下で RLHF を行う方法を示し、トレーニング プロセス中に踏んだ主なピットを記録します。この例には、完全な SFT、報酬モデリング、および RLHF が含まれています。RLHF には、RAFT アルゴリズム (報酬 rAnked FineTuning) または TRL-PPO アライメント モデルを介した 2 つの部分が含まれています。ユーザーの便宜のために、GPT-Neo-2.7B に基づく報酬モデルを Hugging Face リポジトリに提供しました。そのため、報酬モデリングを最初にスキップすることもできます。

この例は、非営利使用のみが許可されている LLaMA に基づいています。LLaMA-7B モデルを使用するには、前のリクエスト フォームに記入する必要があります。テスト環境は8×A100(40G)です。

1.1 環境の準備

LMFlowのインストールパッケージにはTRLも含まれているので、公式の例に従ってLMFlowをインストールするだけです。

git clone https://github.com/OptimalScale/LMFlow.git
cd LMFlow
conda create -n lmflow python=3.9 -y
conda activate lmflow
conda install mpi4py
pip install -e .

上記のインストールでは、依存する PyTorch およびその他のパッケージが自動的にインストールされますが、さらに matplotlib パッケージも手動でインストールされます

1.2 データセットの説明

 例としてDahoas/full-hh-rlhfデータセットを使用します 。各データセット サンプルには、アシスタントからの 1 つのプロンプトと 2 つの応答が含まれています。特に、「選択」とラベル付けされた応答は、「拒否」とラベル付けされた応答よりも人間によって好まれます。データセットには、112,000 のトレーニング サンプルと 12,5,000 のテスト サンプルが含まれています。データセットのサンプル例を次に示します。

" Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

トレーニングの目的で、モデルが応答することを認識できるように、文字の先頭に「###」を追加してプロンプトを再構築します。新しいサンプルの形式は次のようになります。

"###Human: What kind of noises did dinosaurs make? ###Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be ###Human: yes they did ###Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. ###Human: you cant read ###Assistant: 

Chosen response: "You can read?"

Rejected response: "there’s a lot of stuff humans don’t know"

使用する必要があるすべてのデータ セットを ./data/hh_rlhf ディレクトリに準備しました。これらのデータ セットは、LMFlow ディレクトリで次のコマンドを実行して取得する必要があります。

cd data && ./download.sh hh_rlhf && cd -

2 RLHF 前の準備

このセクションでは、最初に SFT モデルと報酬モデルのトレーニングを完了する必要があり、これを完了するには LMFlow を使用します。

2.1 SFT

データセット /home/usrname/LMFlow/data/hh_rlhf/sft/hh_rlhf_sft.json の例を次に示します。優先応答のみを使用するため、112K のトレーニング サンプルが得られます。

{"type": "text_only", "instances": [{"text": "###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant: You’re welcome."}, {"text": "###Human: I'm trying to learn about the salam witch trials###Assistant: If you’re looking for books about witchcraft trials, I can recommend some books for you. But you probably mean you’re looking for more historical information about the Salem witch trials in 1692, and specifically about the salam witch trials in 1692?###Human: What are some good books to learn about the salam witch trials###Assistant: What would you like to learn?  If you’re interested in history, one of the best books is The Witch Persecutions in Colonial America: A History.  If you’re interested in witchcraft as a cultural phenomenon, you might enjoy two excellent books: Religion and the Decline of Magic: Studies in Popular Beliefs in Sixteenth- and Seventeenth-Century England by Keith Thomas and Magic, Witchcraft, and the Otherworld: An Anthropology of Superstition by Jack Goody.  If you’re interested in history specifically as it relates to religion, you might enjoy The Popish Plot, or Prelates' Plot: A History of the Popish Plot in England, by K. J. Everett."}]}

/scripts/run_finetune.shを編集し てパラメータを変更できます 。ここでは例として GPT-Neo-2.7B を使用します。これを、入手した llama-7b モデルのアドレスに置き換える必要があります。

  • --モデル名またはパス: EleutherAI/gpt-neo-2.7B

  • --dataset_path: ${project_dir}/data/hh_rlhf/sft

  • --output_dir: sft モデルを保存するパス

  • --num_train_epochs: 1

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: GPU リソースに応じて調整します。

  • exp_id: hh_rlhf_llama_sft

/scripts/run_finetune.shを編集し てパラメータを変更できます 。ここでは例として GPT-Neo-2.7B を使用します。

次に、次のコマンドを実行して SFT を実行します。

./scripts/run_finetune.sh

次のコマンドを使用して lora トレーニングを使用することもできますが、 run_finetune_with_lora.shを編集して model_name_or_path とデータセットを設定する必要もあります  。

./scripts/run_finetune_with_lora.sh

以下の損失画像の例では、エポックを 4 に設定しましたが、早期に停止し、エポック終了モデルを SFT モデルとして使用しました。また、ロギング ステップは 20 に設定されているため、全体的な外観はよりスムーズになります。

0b01f1ecdbdf6533db77e791419e43e0.png

SFT モデルのトレーニング カーブ。この例では、1.6 エポックのトレーニング カーブをインターセプトします。

私の場合、結果の SFT モデルは /home/usrname/LMFlow/output_models/hh_rlhf_llama_sft/checkpoint-1271に保存されます。

2.2 報酬モデリング

まず、InstructGPT ペーパー: https://arxiv.org/abs/2203.02155 のプロセスに従い、HH-RLHF データセットを使用して報酬モデルをトレーニングします。これには以下が含まれます。

  • 教師あり微調整 (SFT)。

  • データセットを比較することでモデリングに報酬を与えます。

PPO のメモリ負荷が大きいため、この例の設定では、TRL の実装が 7B RM と 7B トレーニング モデルを同時にロードできないことが追跡実験で証明されているため、GPT-Neo-2.7B を使用することを選択します。私たちのRMとして。教師付き微調整はセクション 2.1 と完全に一致しているため、省略します。

報酬モデリングには、与えられた手がかりに基づいて応答をランク付けすることが含まれ、多くの場合人間による注釈が必要になります。幸いなことに、この例では、各プロンプトに対して 2 つの応答が提供されており、そのうちの 1 つがもう 1 つより優れています。新しいデータセットを作成するには、元のデータセットを変更して、「ポジティブ」と「ネガティブ」という 2 つのキーを含めます。前者が優先される応答です。

これは、2 つのサンプルを含むデータセットLMFlow/data/hh_rlhf/rm/hh_rlhf_rm_training.jsonの例です  。完全なデータセットには、Dahoas/full-hh-rlhf データセット 112K トレーニング セットと 12.5K テスト セットが含まれています。

{"instances": [
    {"positive": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: Glad to be helpful.  I’m always happy to assist.", "negative": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that.  Can you tell me more about the situation? Are you using the heater on a fireplace?  Do you have any other heater in the room?  What materials are in the room?  Do you feel any heat from the heater?  How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”.  It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like.  You’d also have to make sure you’re using it correctly.  Is it switched off?  Is it safe to touch?  It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: You’re welcome. I’m glad to help. Is there anything else you want to know?"}, 
    {"positive": "###Human: how do i teach my kids to make forts at home?###Assistant: Do you mean you want to teach them to build a simple structure they can hide in? Or something more elaborate?  I don’t know exactly what you mean by a “fort”, but maybe this gives you some ideas?", "negative": "###Human: how do i teach my kids to make forts at home?###Assistant: There are many ways to make forts!  The simplest might be to just pile up all the furniture in one room.  This works if you have a single room in your house, although it might still be a bit challenging to build a taller and sturdier fort this way.  The best way to build forts is to start with something easy, and then have fun exploring how you can improve it."}
]}

./scripts/run_reward_modeling.shスクリプトを編集して  、前のステップの sft モデルからトレーニングを開始し、「dataset_path」を更新して、使用する目的のデータセットを変更できます。デフォルトでトレーニングに LoRA を使用します。複数のモデルを比較したためです。LoRA は最適化の利点により多くの時間を節約できます。また、完全なトレーニングで得られたモデルは、大幅に有利な RM を提供できないため、パラメータは次のとおりです。続く

  • --model_name_or_path: /home/usrname/LMFlow/output_models/hh_rlhf_rm_sft_gptneo_2_7B/checkpoint-1659

  • --dataset_path: ${project_dir}/data/hh_rlhf/rm/hh_rlhf_rm_training.json

  • --output_dir: 報酬モデルを保存するパス

  • --num_train_epochs: 1

  • --learning_rate: 3e-5

  • --per_device_train_batch_size: GPU メモリ ソースに応じて調整します。

  • --eval_steps: 400

  • --validation_split_percentage: 10

このうち、データ セットの最後の 10% サンプルを自動的に使用して RM をテストします。ここで使用されるデータ セットは、元のデータ セットのトレーニング セット + テスト セットであるため、データ セットの最後の部分はまだ使用されていないことに注意してください。モデルさんに見られました。この例では、 validation_split_percentage を 15 より大きく設定しないでください。そうしないと、SFT で使用されるサンプルの一部がテスト セットで使用されます。これらのデータ セットの処理は、/examples/run_reward_modeling.py で実装されます。 独自 のデータセットは RM のトレーニングに使用され、ニーズに応じてここで変更できます。最後に、次のコードをトレーニングに使用します。

./scripts/run_reward_modeling.sh

以下は、GPT-Neo-2.7B モデルと LLaMA-7B モデルの学習プロセス中の評価損失と評価精度のグラフです。

9767b30bfdc1b7bd3b2f592ac48f627d.png

報酬モデルのトレーニング中の評価曲線。

入手した RM の例

モデル 評価精度 備考
ラマ-7B 79.52% -
ラマ-7B 71.64% SFT なしの LLaMA からの RM
GPT-NEO-2.7B 69.24% -
GPT-NEO-1.3B 65.58% 10000 サンプルのみでトレーニング済み

一般に、大きなモデルの精度も高いことがわかりますが、TRL-PPO は OOM 問題を爆発させるため (クラスメートのフィードバックによると、7B+7B トレーニング trlx の実装も OOM を爆発させるとのことです)、 2.7Bのモデルを使用します。LLaMA-7B モデルの精度ですら約 80% にしか達せず、取得された RM では一部の不要なパターン (繰り返しなど) を検出できない可能性があることに注意してください。それでも比較的高い報酬が得られます。全体として、現在の分類用報酬モデルには依然として大きな欠陥があります。

最後に、取得したモデルは低ランクの LoRA アダプターであるため、*./examples/merge_lora.py* を使用して最終的な RM モデルを取得する必要があります。

3 RAFT アライメント

オリジナル論文: RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment

3.1 アルゴリズムの概要

RAFT のアイデアの起源は、これまでの多くの研究で、トレーニング RM データセットを直接 SFT に使用した場合、最初に RM をトレーニングしてから報酬学習に RL を使用するよりも効果が劣ることがわかっています。1 つの説明は、後者はトレーニング用により多くのデータを保持できるというものですが、データの前方生成自体は PPO に限定されたものではないことに注意してください。また、当時は PPO の調整に多くの時間を費やしましたが、PPO トレーニングは OOM になりやすく、不安定で、モデルの効果が不確実であることがわかりました (記録の途中でさまざまな落とし穴を次のセクションで記録します)また、SFT が垂直場におけるモデルのパフォーマンスを安定して向上させることができることを多くの実験で確認しています。報酬学習に SFT を使用できるかどうかは自然なアイデアです。

具体的には、各ラウンドのトレーニング用に b 個の新しいサンプルを取得したいと考えています。

  • この目的のために、プロンプト セットから bxk プロンプトを選択し、それらを現在のモデルに入力して、対応する出力を取得します。

  • 次に、bxk サンプルの報酬を計算します。

  • SFT トレーニングでは、報酬率が最も高い 1/k のサンプルを選択します。

    • ''top'': 最初の方法は、すべてのサンプルを並べ替えて選択することです。

    • ''ローカル'': 2 番目の方法は、各プロンプトを k 回繰り返し、これらの k サンプルから最も高い報酬を持つサンプルを選択することです。

    • 最初の方法の方が効率的ですが、一部のシナリオ (この例の実験など) では、プロンプト間の比較は無意味であり、ローカルでの並べ替えの方が合理的です。

  • 新しいラウンドが始まります。

ここでは、モデルによって出力されたデータのごく一部のみをトレーニングに使用します。これは、前方操作には適していませんが、後方操作には適しています。deepspeed に基づいた実装では、前方のバッチ サイズが後方の約 5 倍まで拡張できることが観察されたため、1 つの推論のコストは比較的小さいはずだと考えられます。

3.2 例

例として、以前に取得した LLaMA-7B-SFT モデルをトレーニングに使用します。落とし穴のいくつかを説明するために、特定の実験プロセスを記録したいと考えています。そのため、以下では冗長で失敗した試行が多数あります。

データの準備

私たちのトレーニング プロンプト セットは、応答のないDahoas/full-hh-rlhfトレーニング セット内の 112K サンプルです。次に例を示します。

"###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances.  If you carry your phone in a pocket or a purse then you probably want a case.  But if you only need a phone for quick interactions, a case may actually cause more harm than good.  What do you need the phone for?  Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing.  It can also get you in trouble if you have your phone turned off for some reason.  Then you will turn it back on and it won’t do anything.  If you can afford to replace it, then you need a case to protect it.  The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant:"

さらに、テスト用にテスト セットから 2K を抽出します。ただし、このプロンプト セットを TRL-PPO トレーニングに使用した場合 (公正な比較のために後で実験をやり直しました、涙)、コードは実行できるものの、OOM は常に 2 番目のエポックで爆発することがわかりました。長時間デバッグした結果、その理由は、一部のプロンプトが非常に長く、生成されるテキストも比較的長いためであることがわかりました。TRL-PPO に必要なメモリはパスの長さに正の相関があるため、トークン数が 256 未満のプロンプトを生成し、最終的に 82147 個のプロンプトを取得します。

LLaMA-7B-SFT をテストする

最初に SFT モデルをテストしたところ、このモデルが対話履歴に対して複数回の自問自答を返すことが判明したため、生成された応答を「###Human」で切り捨てました。

def _clean_text(self, text):
    split_text = [x for x in text.split("###Human") if x]
    return split_text[0].strip().strip("#")

LMFlow では、使用される RM は */LMFlow/examples/raft_align.py* で指定されます。使用する報酬モデルがセクション 2 の方法に従ってトレーニングされている場合は、そのローカル アドレスまたは Hugging Face リポジトリ ID を指定するだけです。

reward_model_or_path: Optional[str] = field(
    default="weqweasdas/hh_rlhf_rm",
    metadata={
        "help": (
            "reward model name (huggingface) or its path"
        ),
    },
)

ただし、Hugging Face の一部の分類子など、RM が一般的な場合は、「get_reward_function」関数を少し変更する必要がある場合もあります。

3.2.1 最初のトレーニング

LMFlow ディレクトリでのトレーニングには次のコマンドとパラメーターを使用します。

./scripts/run_raft_align.sh
  • --model_name_or_path: /home/usrname/output_models/hh_rlhf_llama-sft (モデルは sft ステップから取得され、セットアップに従って調整されます)

  • --dataset_path:${project_dir}/data/hh_rlhf/rlhf/rlhf_prompt

  • --output_dir: /home/usrname/output_models/hh_rlhf_raft_align

  • --num_train_epochs: 4

  • --learning_rate: 2e-5

  • --per_device_train_batch_size: GPU メモリ ソースに応じて調整します。

  • --inference_batch_size_per_device: GPU メモリ ソースに応じて調整します。

  • --num_raft_iteration 20

  • --top_reward_percentage 0.125; (つまり 1/8)

  • --raft_batch_size 1024 (最終的に各ラウンドにはトレーニング用の 1024 個のサンプルが含まれます)

  • --output_min_length 126

実験はスムーズに実行され、トレーニング報酬は約 2.7 から 3.4 に増加しました。トレーニング中、モデル出力のいくつかの多様性指標をモニタリングしたところ、一部の指標 (distinct-2 など) がトレーニング中に 0.39 から大幅に低下したことに気付きました。 0.22まで下がります。一部の研究では、調整税が(人間の選好を改善するコストとして)RLHF モデルの指標の悪化につながることが示されていますが、これほど大幅な減少は依然として異例です。この目的を達成するために、各反復で生成したサンプルを検査したところ、SFT テストと同様に、最初の反復では、最初のチェックポイントへの応答に # が時折含まれているのに対し (サンプルの約 3%)、ランダム # は検出できないことがわかりました。これは、# を含む応答にも高い報酬が与えられ、トレーニング セットに選択される可能性があることを意味します。その後、状況はどんどん悪化し、最終的には応答の半分にノイズの多い # 記号が含まれるようになりました。

3.2.2 2回目の訓練

上記の問題を解決するために、コードを修正し、各サンプルの応答に冗長な # が含まれているかどうかを検出し、含まれている場合は手動で低い報酬に修正しました。同時に、現在の実装では、トレーニング プロセス全体を監視するために、各ラウンドの SFT に使用されるデータ セットを出力します。コードを変更すると、次の報酬曲線が得られます (テスト中に比較的低い温度を使用するため、テストの報酬は高くなるはずであることに注意してください)。

58fdfe1a1fe29850866c39b9465d06ab.png

RAFT のトレーニング報酬曲線。横軸は 1) データ生成 + 2) 報酬計算とサンプルソート + 3) SFT のラウンドを表します。

横軸は、1) データ生成 2) データソート 3) および選択されたデータセットに対する SFT のラウンドを含むラフト反復を表します。この例では、各ラウンドで 8192 個のサンプルが生成され、1024 個のサンプルが SFT に使用されます。トレーニングの開始時には、トレーニング データ セット内のサンプル (黄色の線) がモデル自体の報酬よりもはるかに高く、この小さなデータ セットに対する SFT の後、モデルの報酬が上昇し始めることがわかります (緑の線と青の線)、収集されたトレーニング データが向上します (黄色の線も上昇します)。上記のトレーニングは、8 x A100 (40G) で約 3 時間かかります。

最終的に得られたモデルは、報酬と多様性の尺度の両方で良好なパフォーマンスを示します。興味のある読者は、詳細について元の論文を参照することをお勧めします。ただし、これはむしろ私たちの旅の出発点に近いものであり、ディスカッションの最後の部分で結果についてさらに議論する前に、まず TRL-PPO をどのように実験したかを文書化します。

4 TRL-PPO の調整

LMFlow のインストール中に、TRL もインストールされるため、実験を直接開始できます。3 か月前、TRL を実行したい場合は、いくつかの小さなバグを手動で修正する必要がありました。ここ数日で、最新バージョンをインストールしてテストしたところ、修正されたようです。

データの準備

まず、TRL-PPO が提供するスクリプトのデータ セットの準備を変更します。TRL-PPO スクリプトを LMFlow/examples に配置していることに注意してください。そうでない場合は、次のデータ セットの場所を少し変更する必要があります。

def build_dataset(config, tokenizer, dataset_name="./data/hh_rlhf/rlhf/rlhf_prompt/prompt.json"):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """

    ds = load_dataset("json", data_files=dataset_name, split="train")['instances'][0]
    texts = [sample['text'] for sample in ds]
    from datasets import Dataset
    ds = Dataset.from_dict({
        "text":texts,
    })
    
    
    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["text"])[:]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds = ds.filter(lambda x: len(x["input_ids"]) <= 256)
    ds.set_format(type="torch")
    print(len(ds))
    return ds

ここではプロンプト データ セットをフィルタリングし、長さを 256 トークン以内にのみ維持していることに注意してください。そうでない場合、テキストが長すぎると OOM エラーが発生します。

ハイパーパラメータ調整

PPO は比較的ハイパーパラメータに依存しますが、いくつかの実験を行った結果、TRL のデフォルトのパラメータはすでに非常に優れていると感じました。学習率を慎重に調整しても、大きな改善を得るのは困難です。変更された内容は次のとおりです。

  • バッチサイズ: 1024/n_gpu、設定では 128;

  • mini_batch_size: 興味深い発見は、PPO の更新バッチ サイズが通常 SFT の更新バッチ サイズよりもはるかに小さいため、処理が大幅に遅くなるということですが、それがコード実装の問題によるものなのか、PPO 自体がより多くの中間変数を必要とするためなのかは不明です。 ;

  • gradient_accumulation_steps: 1

さらに重要なのがKL重みの設定で、当初は単純に探索しようと考えていたのですが、0.1、0.05、0.01と数ラウンド走らせても収束しませんでした(報酬が1時間上昇した後に突然崩れてしまいました)一方、または明らかな報酬の増加はありません)。最終的に、私の選択は、最初に KL の係数を 0 に設定し、次に TRL の ppo_trainer の compute_rewards 関数を変更して、この場合の KL 推定値を出力することです。

def compute_rewards(
        self,
        scores: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        ref_logprobs: torch.FloatTensor,
        masks: torch.LongTensor,
    ):
        """
        Compute per token rewards from scores and KL-penalty.

        Args:
            scores (`torch.FloatTensor`):
                Scores from the reward model, shape (`batch_size`)
            logprobs (`torch.FloatTensor`):
                Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)
        """
        cnt = 0
        rewards, non_score_rewards = [], []
        for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
            # compute KL penalty (from difference in logprobs)
            kl = logprob - ref_logprob
            non_score_reward = -self.kl_ctl.value * kl
            non_score_rewards.append(non_score_reward)
            reward = non_score_reward.clone()
            last_non_masked_index = mask.nonzero()[-1]

            # reward is preference model score + KL penalty
            reward[last_non_masked_index] += score
            rewards.append(reward)
            if cnt < 20:
                print(torch.sum(kl))
                cnt += 1
        return torch.stack(rewards), torch.stack(non_score_rewards)

最後に、報酬曲線の後期段階では、KL オフセットが 500 から 600 にも達する可能性があることが判明し、最終的に比較的小さい KL=0.001 を設定することにしました (論文 [1] と一致)。いくつかの実験では、学習率が小さいほど、困惑度の指標が大幅に向上することがわかりました。[1] で設定された学習率ははるかに小さく、記事で報告されている最大 KL オフセットはわずか約 100 または 200 であることは注目に値します。学習率 5-e6 を試してみましたが、結論は次のとおりです。トレーニングがかなり遅くなります (トレーニングに 1 日以上かかります)、KL オフセットに大きな改善はありません。時間の制約があるため、学習率を下げることは試していません。それが効果的かどうかはわかりません。ハイパーパラメータ設定の問題または TRL-PPO と [1] で実装された違い。トレーニングが軌道に乗っているかどうかを監視するために、常にいくつかのサンプルをサンプリングし、その KL 推定値を確認することをお勧めします。

さらに、モデルの応答が短すぎる場合があり、ppo_trainer の次のチェックでエラーが報告されます。1 つの方法は、エラーを直接コメント アウトすることです。もう 1 つの方法は、サンプルをテストして応答が短すぎるサンプルを破棄することです。両方の方法を試しましたが、ほぼ同じ効果があるようです。

def batched_forward_pass(
    ......
    
    if len(logprobs[j, start:end]) < 2:
    	raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.")
    
    ......

TRL-PPO では KL を推定する必要があるため、生成された設定を自由に調整することはできません。そうしないと、おそらく KL の推定に影響を与える可能性があることに注意してください。

generation_kwargs = {
    # "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,
}

たとえば、上記の短すぎる応答の問題を解決するために、モデルに長い応答を出力させるために最小出力長を設定しようとしましたが、設定後、KL 推定値のほぼ半分が負になることがわかりました。

訓練

PPO のトレーニングでは、モデル自体が質問と回答を行って複数回の応答を生成する問題も発生しますが、この場合はトレーニングできないため、それに応じて出力全体も切り捨てられます。それに応じて戻り値を切り詰める必要があります。

output_min_length = 64
output_max_length = 128
output_length_sampler = LengthSampler(output_min_length, output_max_length)
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 1}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    with torch.no_grad():
        response_tensors = ppo_trainer.generate(
            query_tensors, 
            batch_size=1, ## adjust according to your memory source 
            return_prompt=False, 
            length_sampler=output_length_sampler, 
            **generation_kwargs)

    full_responses = tokenizer.batch_decode(response_tensors)
    clean_texts = [clean_text(tmp_text) for tmp_text in full_responses]
    clean_response_tensors = [tokenizer.encode(text) for text in clean_texts]
    lengths = [len(clean_tensor) for clean_tensor in clean_response_tensors]

    response_tensors = [response_tensors[i][:np.max([lengths[i]-2, 1])] for i in range(len(response_tensors))]

    batch["response"] = clean_texts

    texts_for_rewards = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts_for_rewards, **sent_kwargs)
    rewards = [output[0]["score"] for output in pipe_outputs]

多くの調整を経て、得られた PPO モデルには奇妙なパターンがいくつかあります。まず、PPO モデルには出力に多数のランダム # が混在するため、これらのサンプルを破棄する検出を追加するか、手動でパラメータを与える必要があります。比較的ネガティブな報酬でしたが、参加後、PPO モデルがランダムな # を出力する現象が緩和され、その結果、PPO が「:)」のような表情を繰り返すようになりました。 :) 動作が多かったので、PPO は ;) を繰り返すようになりました。幸いなことに、後の 2 つの問題はそれほど深刻ではなく、比率も比較的低く、許容範囲内です。DRL 自体は比較的ブラックボックスな手法であるため、モデルがこれらの表情を生成する傾向がある理由を直接知ることはできませんが、おそらく RM はこの種の表情を好み、PPO が RM の欠点を利用しているのではないかと推測します。

TRL-PPO はデフォルトでランダムな生成長を使用します。出力長を 128 に固定する方法と、[64, 128] から出力長をランダムに抽出する 2 つの方法を試したところ、他の設定が適切な場合に学習が向上することがわかりました。ただし、後者の方が出力の重複を避けるのに役立つようで、最終的なモデル出力の見栄えが良くなります。

PPO では主にパラメータの調整に時間がかかり、パラメータが適切な場合、トレーニング セッションには 8 ~ 12 時間ほどかかります。

5 件のディスカッション

以下にランダムサンプリングの例をいくつか示します。PPO と RAFT の両方がモデルの応答スタイルを大幅に変更していることがわかります。全体として、RAFT と連携したモデルは通常、より詳細な返答をする傾向があり、PPO モデルはより丁寧で前向きですが、SFT モデルは十分に役に立たないようで、指示どおりにアドバイスを提供しないことがよくあります。同時に、PPO が意味のない記号を出力する場合があり、RAFT の応答に冗長な単語が含まれる場合があることも観察されました。

これは、報酬モデルが応答の質を完全に特徴付けることができず、PPO と RAFT の両方が報酬モデルのこの不完全性をある程度利用して高い報酬を獲得するためであると考えられます。明らかに、これは RLHF 探索の出発点にすぎず、改善の余地はまだたくさんあります。モデルのパフォーマンスをさらに向上させるために、たとえば、報酬モデルを改善したり (LLaMA-7B-RM の使用など)、より高度な生成戦略を試して、生成されるテキストの品質を向上させることもできます (対照検索など)。 、https://zhuanlan .zhihu.com/p/629920420を参照)。それまでの間、LLM をさらに楽しむために LMFlow フレームワークをチェックしてください。

OptimalScale/LMFlow: An Extensible Toolkit for Finetuning and Inference of Large Foundation Models. Large Model for All. (github.com)
https://github.com/OptimalScale/LMFlow

(以下の図は表示の便宜上、表から変換されています)プロンプト内の ### は改行に置き換えられ、太字で表示されます)

b41ef37ca8d11692d81da80c315486ac.png

83972c82998b2808b1a547bc57f4a2ec.png

889c805335ede0b98c06b15c709efc7a.png

6d3b4740f2e18f7e41a659e536ed9207.png

d3fb1c94aae475b185be50aab92ede44.png

[1] 人間のフィードバックからの強化学習を使用して、有益で無害な 326 アシスタントをトレーニングする



著者によって承認された、Hugging Face アカウントは、WeChat パブリック アカウント プラットフォームでのオリジナルのリリースを示します。転載したい場合は、この記事の下にメッセージを残してください。

著者のZhihuアカウント:「教師を尊重し、張北海を教えます」は、友好的な交流と議論を歓迎します。


NLP グループに入る —> NLP 交換グループに参加する

おすすめ

転載: blog.csdn.net/qq_27590277/article/details/131318568