Did you know? You can now fine-tune any Llama-2 model on your own data with just a few lines of code!
Even models with 7 billion parameters can be trained on a single A100 GPU, thanks to the magic of 4-bit quantization and PEFT (parameter-efficient fine-tuning)!
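Why 4-bit quantization matters comes down to simple arithmetic: weight memory is the number of parameters times the bits per parameter. The sketch below counts the model weights only. It ignores activations, the small PEFT adapter weights and their optimizer states, and quantization overhead, so real usage is higher. It still shows why a 7B model fits comfortably on an A100, which comes in 40 GB and 80 GB variants.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed for the model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Compare common precisions for a 7-billion-parameter model (weights only).
for label, bits in [("fp32", 32), ("fp16/bf16", 16), ("4-bit", 4)]:
    print(f"{label:>9}: {weight_memory_gb(7e9, bits):.1f} GB")
```

With these assumptions, the weights take about 28 GB in fp32, 14 GB in 16-bit, and 3.5 GB in 4-bit.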
Fine-tuning a language model with PPO roughly consists of three steps:
Rollout: The language model generates a response, or continuation, for a query. The query may be, for example, the beginning of a sentence.
Evaluation: The query and the generated response are evaluated with a function, a model, human feedback, or some combination of these. Importantly, this step must produce a single scalar value for each query/response pair.
Optimization: This is the most complex part. Using the query/response pairs, the log probability of every token in the sequence is computed twice. It is computed once with the model being trained and once with a reference model, which is usually the pre-trained model before fine-tuning. The KL divergence between the two models' outputs is used as an additional reward signal, which keeps the generated responses from drifting too far from the reference language model. Finally, the main language model is trained with PPO.
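The optimization step turns the scalar score and the KL term into per-token rewards. A common scheme works as follows. Each token receives a penalty proportional to the log-ratio between the trained model and the reference model, and the scalar score from the evaluation step is added to the final token. The sketch below shows that scheme in plain Python with hypothetical names. The coefficient `kl_coef=0.2` is only an illustrative value, not a claim about any particular configuration.

```python
from typing import List


def kl_penalized_rewards(
    logprobs: List[float],
    ref_logprobs: List[float],
    score: float,
    kl_coef: float = 0.2,
) -> List[float]:
    """Hypothetical helper: per-token rewards = KL penalty, plus the score on the last token."""
    if len(logprobs) != len(ref_logprobs):
        raise ValueError("logprobs and ref_logprobs must have the same length")
    if not logprobs:
        raise ValueError("the response must contain at least one token")

    rewards = []
    for lp, ref_lp in zip(logprobs, ref_logprobs):
        # Per-token log-ratio, a simple estimate of the KL between the two models.
        kl = lp - ref_lp
        # Penalize tokens where the trained model diverges from the reference.
        rewards.append(-kl_coef * kl)

    # The scalar score from the evaluation step is credited to the final token.
    rewards[-1] += score
    return rewards
```

For example, consider policy log probs `[-1.0, -2.0, -0.5]` and reference log probs `[-1.0, -1.5, -1.0]`, with a score of 1.0. The resulting rewards are approximately `[0.0, 0.1, 0.9]`.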
This process is illustrated in the schematic diagram below:
Install:
To install the library via pip:
pip install trl
If you want to run the examples from the repository, some additional libraries need to be installed. First clone the repository, then install it with pip:
git clone https://github.com/lvwerra/trl.git
cd trl/
pip install .
If you want to develop TRL itself, install it in editable mode instead:
pip install -e .
How to use:
SFTTrainer
Here is a basic example of how to use the SFTTrainer from the library. SFTTrainer is a lightweight wrapper around the Transformers Trainer that makes it easy to fine-tune language models or adapters on custom datasets.
# imports
from datasets import load_dataset
from trl import SFTTrainer
# get dataset
dataset = load_dataset("imdb", split="train")
# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
# train
trainer.train()
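Conceptually, `dataset_text_field` names the column whose text becomes the training example, and `max_seq_length` caps its length in tokens. For causal-LM fine-tuning the labels are simply the input ids, because Transformers causal-LM models shift them by one position internally. The sketch below is a hypothetical simplification with a toy whitespace tokenizer and no padding or packing. It is not SFTTrainer's actual code.

```python
from typing import Dict, List


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Hypothetical whitespace tokenizer standing in for a real Hugging Face tokenizer."""
    ids = []
    for word in text.split():
        if word not in vocab:
            vocab[word] = len(vocab)  # assign the next free id to unseen words
        ids.append(vocab[word])
    return ids


def build_sft_example(
    record: Dict[str, str],
    vocab: Dict[str, int],
    text_field: str = "text",
    max_seq_length: int = 512,
) -> Dict[str, List[int]]:
    """Tokenize one dataset row and truncate it to max_seq_length tokens."""
    input_ids = toy_tokenize(record[text_field], vocab)[:max_seq_length]
    # Labels are a copy of the inputs; the model shifts them internally for next-token prediction.
    return {"input_ids": input_ids, "labels": list(input_ids)}


vocab: Dict[str, int] = {}
example = build_sft_example({"text": "a great great movie"}, vocab, max_seq_length=3)
print(example)
```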
RewardTrainer
Below is a basic example of how to use the RewardTrainer from the library.
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)
# train
trainer.train()
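The example above elides the dataset format. Reward models of this kind are typically trained on preference pairs: for each prompt, a preferred ("chosen") and a dispreferred ("rejected") response. The model learns to score the chosen one higher, usually with the pairwise loss -log(sigmoid(r_chosen - r_rejected)). The sketch below computes that loss in plain Python from already-computed scalar rewards. The function name is hypothetical, and this illustrates the objective rather than RewardTrainer's internals.

```python
import math
from typing import List


def pairwise_reward_loss(chosen: List[float], rejected: List[float]) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of preference pairs."""
    if len(chosen) != len(rejected) or not chosen:
        raise ValueError("need equally sized, non-empty lists of rewards")

    total = 0.0
    for r_c, r_r in zip(chosen, rejected):
        x = -(r_c - r_r)
        # -log(sigmoid(m)) == softplus(-m); this form avoids overflow for large |m|.
        total += max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return total / len(chosen)
```

The loss equals log 2 when the two rewards are tied. It shrinks toward zero as the chosen response is scored increasingly higher than the rejected one.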
PPOTrainer
Below is a basic example of how to use the PPOTrainer from the library. Given a query, the language model generates a response, which is then evaluated. The evaluation can be done by a human, or it can be the output of another model.
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch
# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = PPOConfig(
batch_size=1,)
# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
# get model response
response_tensor = respond_to_batch(model, query_tensor)
# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]
# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
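The call to `ppo_trainer.step` hides the actual policy update. At its core, PPO maximizes a clipped surrogate objective. For each token, the probability ratio between the updated and the old policy is clipped to [1 - clip_range, 1 + clip_range], so that a single update cannot move the policy too far. The sketch below computes that loss in plain Python with hypothetical names. It leaves out advantage estimation and the value-head loss, which a full implementation such as PPOTrainer also handles. The default `clip_range=0.2` is a common choice and is used here only for illustration.

```python
import math
from typing import List


def ppo_clipped_loss(
    logprobs_new: List[float],
    logprobs_old: List[float],
    advantages: List[float],
    clip_range: float = 0.2,
) -> float:
    """Negative clipped surrogate objective, averaged over tokens."""
    if not (len(logprobs_new) == len(logprobs_old) == len(advantages)) or not advantages:
        raise ValueError("inputs must be non-empty and of equal length")

    total = 0.0
    for new, old, adv in zip(logprobs_new, logprobs_old, advantages):
        ratio = math.exp(new - old)  # pi_new(token) / pi_old(token)
        unclipped = ratio * adv
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range) * adv
        # PPO takes the pessimistic (smaller) term; negate it to obtain a loss to minimize.
        total += -min(unclipped, clipped)
    return total / len(advantages)
```

When the policy has not changed, the ratio is 1 and the loss reduces to the negative mean advantage. When the ratio reaches 2 on a positive advantage, clipping caps that token's contribution at 1.2 times the advantage.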
Advanced Example: IMDB Sentiment Classification
For a detailed example, see the script examples/scripts/sentiment_tuning.py in the project repository. Here are a few samples from the language model before and after optimization:
GitHub project address: https://github.com/lvwerra/trl