PPO Trainer
TRL supports the PPO Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data via a Reward Model. For a full example, have a look at examples/notebooks/gpt2-sentiment.ipynb. The trainer is heavily inspired by the original OpenAI learning to summarize work.
The first step is to train your SFT model (see the SFTTrainer) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a Reward model (see the RewardTrainer), which will be used to optimize the SFT model using the PPO algorithm.
Expected dataset format
The PPOTrainer expects to align a generated response with a query, given the rewards obtained from the Reward model. During each step of the PPO algorithm, we sample a batch of prompts from the dataset and use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
Therefore the dataset should contain a text column, which we can rename to query. Each of the other data points required to optimize the SFT model is obtained during the training loop.
Here is an example with the HuggingFaceH4/cherry_picked_prompts dataset:
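A minimal sketch of loading the prompts and keeping a single query column; the prompt, meta, and completion column names are assumptions about this dataset:

```python
from datasets import load_dataset

# Load the prompts and keep only a single text column named "query".
# The column names ("prompt", "meta", "completion") are assumptions about this dataset.
dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])
```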
Resulting in the following subset of the dataset:
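An illustrative sketch of the processed dataset when printed (the row count is omitted, since it depends on the split):

```python
print(dataset)
# Dataset({
#     features: ['query'],
#     num_rows: ...
# })
```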
Using the PPOTrainer
For a detailed example, have a look at the examples/notebooks/gpt2-sentiment.ipynb notebook. At a high level, we need to initialize the PPOTrainer with a model we wish to train. Additionally, we require a reward_model which we will use to rate the generated response.
Initializing the PPOTrainer
The PPOConfig dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
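A minimal sketch of the configuration; the model name and learning rate are example values, not prescriptions:

```python
from trl import PPOConfig

config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",  # example SFT model, swap in your own
    learning_rate=1.41e-5,           # example value
)
```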
Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the PPOTrainer automatically. The model can be initialized as follows:
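A sketch of loading the model with a value head and its tokenizer; padding with the EOS token is a common convention for GPT-2-style tokenizers and is an assumption here:

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# GPT-2-style tokenizers have no pad token; reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
```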
As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using transformers.pipeline for ease of use.
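A sketch using a sentiment classifier as the reward model; lvwerra/distilbert-imdb is an example checkpoint, and any text-classification pipeline works:

```python
from transformers import pipeline

# Any text-classification pipeline can act as the reward model here.
reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
```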
Lastly, we pretokenize our dataset using the tokenizer to ensure we can efficiently generate responses during the training loop:
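A sketch of the pretokenization step, adding an input_ids column that the trainer's dataloader will consume:

```python
def tokenize(sample):
    # Encode the query text into token ids once, up front.
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

dataset = dataset.map(tokenize, batched=False)
```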
Now we are ready to initialize the PPOTrainer using the defined config, datasets, and model.
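A sketch of constructing the trainer from the pieces defined above; the reference model is created internally when none is passed:

```python
from trl import PPOTrainer

ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)
```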
Starting the training loop
Because the PPOTrainer needs an active reward per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment reward_model initialized above.

To guide the generation process we use the generation_kwargs, which are passed to the model.generate method of the SFT model during each step. A more detailed example can be found over here.
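A sketch of typical sampling settings; the exact values are examples rather than requirements:

```python
generation_kwargs = {
    "min_length": -1,      # don't force a minimum length
    "top_k": 0.0,          # disable top-k sampling
    "top_p": 1.0,          # disable nucleus sampling
    "do_sample": True,     # sample instead of greedy decoding
    "pad_token_id": tokenizer.eos_token_id,
}
```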
We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the reward_model and pass these rewards to the ppo_trainer.step method. The ppo_trainer.step method will then optimize the SFT model using the PPO algorithm.
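A sketch of the training loop under the setup above; how the reward is read from the pipeline output (the top_k=None argument and the POSITIVE label name) is an assumption about the example sentiment classifier:

```python
import torch
from tqdm import tqdm

epochs = 10
for epoch in tqdm(range(epochs), desc="epoch"):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]

        # Generate a response for every query; return only the new tokens.
        response_tensors = ppo_trainer.generate(
            query_tensors, return_prompt=False, **generation_kwargs
        )
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        # Score each query + response pair with the reward model.
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = reward_model(texts, top_k=None)  # scores for every label
        rewards = [
            # Use the positive-sentiment score as the scalar reward
            # ("POSITIVE" is the label name assumed for this classifier).
            torch.tensor(next(d["score"] for d in output if d["label"] == "POSITIVE"))
            for output in pipe_outputs
        ]

        # Run a PPO optimisation step and log the statistics.
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

# Save the tuned model to an example output directory.
ppo_trainer.save_pretrained("my_ppo_model")
```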
Logging
While training and evaluating we log the following metrics:
- stats: The statistics of the PPO algorithm, including the loss, entropy, etc.
- batch: The batch of data used to train the SFT model.
- rewards: The rewards obtained from the Reward model.
PPOTrainer
class trl.PPOTrainer
( config: PPOConfig = None, model: PreTrainedModelWrapper = None, ref_model: typing.Optional[trl.models.modeling_base.PreTrainedModelWrapper] = None, tokenizer: PreTrainedTokenizerBase = None, dataset: typing.Union[torch.utils.data.dataset.Dataset, datasets.arrow_dataset.Dataset, NoneType] = None, optimizer: typing.Optional[torch.optim.optimizer.Optimizer] = None, data_collator: typing.Optional[typing.Callable] = None, num_shared_layers: typing.Optional[int] = None, lr_scheduler: typing.Optional[torch.optim.lr_scheduler._LRScheduler] = None )
Parameters
- **config** (PPOConfig) – Configuration object for PPOTrainer. Check the documentation of PPOConfig for more details.
- **model** (PreTrainedModelWrapper) – Model to be optimized, BOINC AI transformer model with a value head. Check the documentation of PreTrainedModelWrapper for more details.
- **ref_model** (PreTrainedModelWrapper, optional) – Reference model to be used for the KL penalty, BOINC AI transformer model with a causal language modelling head. Check the documentation of PreTrainedModelWrapper for more details. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized, with shared layers.
- **tokenizer** (PreTrainedTokenizerBase) – Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
- **dataset** (Union[torch.utils.data.Dataset, datasets.Dataset], optional) – PyTorch dataset or BOINC AI dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be created outside the trainer: the user needs to design their own dataloader and make sure the batch size used is the same as the one specified in the configuration object.
- **optimizer** (torch.optim.Optimizer, optional) – Optimizer to be used for training. If no optimizer is provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration object.
- **data_collator** (DataCollatorForLanguageModeling, optional) – Data collator to be used for training and passed along to the dataloader.
- **num_shared_layers** (int, optional) – Number of layers to be shared between the model and the reference model, if no reference model is passed. If no number is provided, all the layers will be shared.
- **lr_scheduler** (torch.optim.lr_scheduler, optional) – Learning rate scheduler to be used for training.
The PPOTrainer uses Proximal Policy Optimization to optimise language models. Note that this trainer is heavily inspired by the original OpenAI learning to summarize work here: https://github.com/openai/summarize-from-feedback
batched_forward_pass
( model: PreTrainedModelWrapper, queries: Tensor, responses: Tensor, model_inputs: dict, return_logits: bool = False, response_masks: typing.Optional[torch.Tensor] = None ) → (tuple)
Parameters
- queries (torch.LongTensor) – List of tensors containing the encoded queries, shape (batch_size, query_length)
- responses (torch.LongTensor) – List of tensors containing the encoded responses, shape (batch_size, response_length)
- return_logits (bool, optional, defaults to False) – Whether to return all_logits. Set to False if logits are not needed to reduce memory consumption.
Returns
(tuple)
- all_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
- all_ref_logprobs (torch.FloatTensor): Log probabilities of the responses under the reference model, shape (batch_size, response_length)
- all_values (torch.FloatTensor): Values of the responses, shape (batch_size, response_length)
Calculate model outputs in multiple batches.
compute_rewards
( scores: FloatTensor, logprobs: FloatTensor, ref_logprobs: FloatTensor, masks: LongTensor )
Parameters
- scores (torch.FloatTensor) – Scores from the reward model, shape (batch_size)
- logprobs (torch.FloatTensor) – Log probabilities of the model, shape (batch_size, response_length)
- ref_logprobs (torch.FloatTensor) – Log probabilities of the reference model, shape (batch_size, response_length)
Compute per token rewards from scores and KL-penalty.
create_model_card
( path: str, model_name: typing.Optional[str] = 'TRL Model' )
Parameters
- path (str) – The path to save the model card to.
- model_name (str, optional) – The name of the model, defaults to TRL Model.
Creates and saves a model card for a TRL model.
gather_stats
( stats ) → dict[str, Any]
Parameters
- stats (dict[str, Any]) – A dictionary of stats to be gathered. The stats should contain torch tensors.
Returns
dict[str, Any]
A dictionary of stats with the tensors gathered.
Gather stats from all processes. Useful in the context of distributed training.
generate
( query_tensor: typing.Union[torch.Tensor, typing.List[torch.Tensor]], length_sampler: typing.Callable = None, batch_size: int = 4, return_prompt: bool = True, **generation_kwargs ) → torch.LongTensor
Parameters
- query_tensor (torch.LongTensor) – A tensor of shape (seq_len) containing query tokens or a list of tensors of shape (seq_len).
- generation_kwargs (dict[str, Any]) – Keyword arguments for generation.
- length_sampler (Callable, optional) – Callable that returns the number of newly generated tokens.
- batch_size (int, optional) – Batch size used for generation, defaults to 4.
- return_prompt (bool, optional) – If set to False the prompt is not returned but only the newly generated tokens, defaults to True.
Returns
torch.LongTensor
A tensor of shape (batch_size, gen_len) containing response tokens.
Generate a response with the model given the query tensor. Calls the generate method of the model.
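A hypothetical usage sketch; query_tensors and generation_kwargs are defined as in the training loop above, and LengthSampler from trl.core is assumed to be available as an int-returning callable:

```python
from trl.core import LengthSampler  # assumed helper; any callable returning an int works

output_length_sampler = LengthSampler(4, 16)  # sample 4-16 new tokens per response

response_tensors = ppo_trainer.generate(
    query_tensors,                         # list of encoded queries (torch.LongTensor)
    length_sampler=output_length_sampler,  # controls how many new tokens to generate
    return_prompt=False,                   # return only the newly generated tokens
    **generation_kwargs,
)
```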
log_stats
( stats: dict, batch: dict, rewards: typing.List[torch.FloatTensor], columns_to_log: typing.List[str] = ['query', 'response'] )
Parameters
- stats (dict[str, Any]) – A dictionary of training stats.
- batch (dict[str, Any]) – A dictionary of batch data; it contains the queries and responses.
- rewards (List[torch.FloatTensor]) – A tensor of rewards.
A function that logs all the training stats. Call it at the end of each epoch.
loss
( old_logprobs: FloatTensor, values: FloatTensor, logits: FloatTensor, vpreds: FloatTensor, logprobs: FloatTensor, mask: LongTensor, advantages: FloatTensor, returns: FloatTensor )
Parameters
- old_logprobs (torch.FloatTensor) – Log probabilities of the model, shape (batch_size, response_length)
- values (torch.FloatTensor) – Values of the value head, shape (batch_size, response_length)
- rewards (torch.FloatTensor) – Rewards from the reward model, shape (batch_size, response_length)
- logits (torch.FloatTensor) – Logits of the model, shape (batch_size, response_length, vocab_size)
- v_pred (torch.FloatTensor) – Values of the value head, shape (batch_size, response_length)
- logprobs (torch.FloatTensor) – Log probabilities of the model, shape (batch_size, response_length)
Calculate policy and value losses.
prepare_dataloader
( dataset: typing.Union[torch.utils.data.dataset.Dataset, datasets.arrow_dataset.Dataset], data_collator = None ) → torch.utils.data.DataLoader
Parameters
- dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) – PyTorch dataset or BOINC AI dataset. If a BOINC AI dataset is passed, the dataset will be preprocessed by removing the columns that are not used by the model.
- data_collator (Optional[function]) – Data collator function.
Returns
torch.utils.data.DataLoader
PyTorch dataloader
Prepare the dataloader for training.
record_step_stats
( kl_coef: float, **data ) → stats (dict)
Parameters
- kl_coef (float) – KL coefficient
- data (dict) – Dictionary of training step data
Returns
stats (dict)
Dictionary of training step statistics
Record training step statistics.
step
( queries: typing.List[torch.LongTensor], responses: typing.List[torch.LongTensor], scores: typing.List[torch.FloatTensor], response_masks: typing.Optional[typing.List[torch.LongTensor]] = None ) → dict[str, Any]
Parameters
- queries (List[torch.LongTensor]) – List of tensors containing the encoded queries of shape (query_length)
- responses (List[torch.LongTensor]) – List of tensors containing the encoded responses of shape (response_length)
- scores (List[torch.FloatTensor]) – List of tensors containing the scores.
- response_masks (List[torch.FloatTensor], optional) – List of tensors containing masks of the response tokens.
Returns
dict[str, Any]
A summary of the training statistics
Run a PPO optimisation step given a list of queries, model responses, and rewards.
train_minibatch
( old_logprobs: FloatTensor, values: FloatTensor, logprobs: FloatTensor, logits: FloatTensor, vpreds: FloatTensor, mask: LongTensor, advantages: FloatTensor, returns: FloatTensor ) → train_stats (dict[str, torch.Tensor])
Parameters
- logprobs (torch.FloatTensor) – Log probabilities of the model, shape [batch_size, response_length]
- values (torch.FloatTensor) – Values of the value head, shape [batch_size, response_length]
- query (torch.LongTensor) – Encoded queries, shape [batch_size, query_length]
- response (torch.LongTensor) – Encoded responses, shape [batch_size, response_length]
- model_input (torch.LongTensor) – Concatenated queries and responses, shape [batch_size, query_length+response_length]
Returns
train_stats (dict[str, torch.Tensor])
Dictionary of training statistics
Train one PPO minibatch
class trl.PPOConfig
( exp_name: str = 'doc-buil', seed: int = 0, log_with: typing.Union[typing.Literal['wandb', 'tensorboard'], NoneType] = None, task_name: typing.Optional[str] = None, model_name: typing.Optional[str] = None, query_dataset: typing.Optional[str] = None, reward_model: typing.Optional[str] = None, remove_unused_columns: bool = True, tracker_kwargs: dict = <factory>, accelerator_kwargs: dict = <factory>, project_kwargs: dict = <factory>, tracker_project_name: str = 'trl', push_to_hub_if_best_kwargs: dict = <factory>, steps: int = 20000, learning_rate: float = 1e-05, adap_kl_ctrl: bool = True, init_kl_coef: typing.Optional[float] = 0.2, kl_penalty: typing.Literal['kl', 'abs', 'mse', 'full'] = 'kl', target: typing.Optional[float] = 6, horizon: typing.Optional[float] = 10000, gamma: float = 1, lam: float = 0.95, cliprange: float = 0.2, cliprange_value: float = 0.2, vf_coef: float = 0.1, batch_size: int = 256, forward_batch_size: typing.Optional[int] = None, mini_batch_size: int = 1, gradient_accumulation_steps: int = 1, world_size: typing_extensions.Annotated[int, Suppress] = None, ppo_epochs: int = 4, max_grad_norm: typing.Optional[float] = None, optimize_cuda_cache: bool = False, early_stopping: bool = False, target_kl: float = 1, compare_steps: int = 1, ratio_threshold: float = 10.0, use_score_scaling: bool = False, use_score_norm: bool = False, score_clip: typing.Optional[float] = None, is_encoder_decoder: typing.Union[typing_extensions.Annotated[bool, Suppress], NoneType] = None, is_peft_model: typing.Union[typing_extensions.Annotated[bool, Suppress], NoneType] = None, backward_batch_size: typing_extensions.Annotated[int, Suppress] = None, global_backward_batch_size: typing_extensions.Annotated[int, Suppress] = None, global_batch_size: typing_extensions.Annotated[int, Suppress] = None )
Configuration class for PPOTrainer
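A sketch that overrides a few of the fields listed in the signature above; the values are examples only:

```python
from trl import PPOConfig

config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",  # example model name
    learning_rate=1.41e-5,           # example learning rate
    batch_size=128,                  # prompts sampled per PPO step
    mini_batch_size=16,              # minibatch size for the inner PPO epochs
    ppo_epochs=4,                    # optimisation epochs per batch of rollouts
    init_kl_coef=0.2,                # initial KL penalty coefficient
    target=6,                        # target KL for the adaptive KL controller
)
```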