DPO Trainer
TRL supports the DPO Trainer for training language models from preference data, as described in the paper Direct Preference Optimization: Your Language Model is Secretly a Reward Model by Rafailov et al., 2023. For a full example, have a look at examples/dpo.py.
As always, the first step is to train an SFT model, to ensure that the data we train on is in-distribution for the DPO algorithm.
The DPO trainer expects a very specific dataset format, because the model is trained to directly optimize a preference: given two sentences, which one is the more relevant. An example of such preference data is the Anthropic/hh-rlhf dataset.
If you use the default DPODataCollatorWithPadding data collator, the final dataset object should contain 3 entries, named:
prompt
chosen
rejected
The prompt entry contains the context inputs, chosen contains the corresponding chosen responses, and rejected contains the corresponding negative (rejected) responses. A prompt can have multiple responses; in that case the prompt is repeated in the dictionary's value arrays, once per chosen/rejected pair.
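A minimal sketch of this layout as a plain Python dictionary follows. The strings are illustrative only. check_dpo_dict is a hypothetical helper (not part of TRL) that checks the three entries exist and line up one-to-one.

```python
# Illustrative preference data in the prompt/chosen/rejected layout.
# A prompt with several preference pairs is repeated once per pair.
dpo_dataset_dict = {
    "prompt": [
        "hello",
        "What is your name?",
        "What is your name?",
    ],
    "chosen": [
        "hi nice to meet you",
        "My name is Mary",
        "My name is Mary",
    ],
    "rejected": [
        "leave me alone",
        "Whats it to you?",
        "I dont have a name",
    ],
}


def check_dpo_dict(data: dict) -> int:
    """Hypothetical helper: verify the three required entries exist and have
    equal lengths. Returns the number of preference pairs."""
    required = ("prompt", "chosen", "rejected")
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"missing entries: {missing}")
    lengths = {len(data[key]) for key in required}
    if len(lengths) != 1:
        raise ValueError("prompt, chosen and rejected must have the same length")
    return lengths.pop()


print(check_dpo_dict(dpo_dataset_dict))  # 3
```

If the datasets library is installed, datasets.Dataset.from_dict(dpo_dataset_dict) turns such a dictionary into a Dataset with the three columns.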
The DPO trainer expects an AutoModelForCausalLM model. PPO, by contrast, expects an AutoModelForCausalLMWithValueHead, which adds a head for the value function.
For a detailed example, have a look at the examples/dpo.py script. At a high level, we need to initialize the DPOTrainer with the following:
- the model we wish to train;
- a reference model, ref_model, which we use to calculate the implicit rewards of the preferred and rejected responses;
- beta, the hyperparameter of the implicit reward;
- a dataset containing the 3 entries listed above.
Note that model and ref_model need to have the same architecture (i.e. decoder-only or encoder-decoder).
After this, start training by calling the trainer's train() method.
Note that beta is the temperature parameter for the DPO loss, typically set somewhere in the range of 0.1 to 0.5. The reference model is ignored as beta -> 0.
While training and evaluating, we record the following reward metrics:
rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
rewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
rewards/accuracies: the fraction of examples in which the chosen reward is greater than the corresponding rejected reward
rewards/margins: the mean difference between the chosen and corresponding rejected rewards
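The following plain-Python sketch makes these definitions concrete. It computes the four metrics from per-example sequence log-probabilities. The function name reward_metrics and the list-of-floats inputs are assumptions for illustration; the trainer itself works on tensors.

```python
from statistics import mean
from typing import Dict, List


def reward_metrics(
    policy_chosen_logps: List[float],
    policy_rejected_logps: List[float],
    reference_chosen_logps: List[float],
    reference_rejected_logps: List[float],
    beta: float = 0.1,
) -> Dict[str, float]:
    # Implicit reward per example: beta * (policy log-prob - reference log-prob).
    chosen = [beta * (p - r) for p, r in zip(policy_chosen_logps, reference_chosen_logps)]
    rejected = [beta * (p - r) for p, r in zip(policy_rejected_logps, reference_rejected_logps)]
    return {
        "rewards/chosen": mean(chosen),
        "rewards/rejected": mean(rejected),
        # Fraction of pairs whose chosen reward is strictly greater.
        "rewards/accuracies": mean(1.0 if c > r else 0.0 for c, r in zip(chosen, rejected)),
        "rewards/margins": mean(c - r for c, r in zip(chosen, rejected)),
    }


metrics = reward_metrics([-10.0, -12.0], [-14.0, -11.0], [-11.0, -12.0], [-13.0, -12.0], beta=0.5)
print(metrics)
# {'rewards/chosen': 0.25, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.5, 'rewards/margins': 0.25}
```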
( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,
  ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None,
  beta: float = 0.1,
  args: TrainingArguments = None,
  data_collator: typing.Optional[DataCollator] = None,
  label_pad_token_id: int = -100,
  padding_value: int = 0,
  truncation_mode: str = 'keep_end',
  train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None,
  eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, typing.Dict[str, datasets.arrow_dataset.Dataset], NoneType] = None,
  tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
  model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None,
  callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None,
  optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
  preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None,
  max_length: typing.Optional[int] = None,
  max_prompt_length: typing.Optional[int] = None,
  max_target_length: typing.Optional[int] = None,
  peft_config: typing.Optional[typing.Dict] = None,
  is_encoder_decoder: typing.Optional[bool] = None,
  disable_dropout: bool = True,
  generate_during_eval: bool = False,
  compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], typing.Dict], NoneType] = None )
Parameters
model (transformers.PreTrainedModel) — The model to train, preferably an AutoModelForCausalLM.
ref_model (PreTrainedModelWrapper) — BOINC AI transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
beta (float, defaults to 0.1) — The beta factor in the DPO loss. Higher beta means less divergence from the initial policy.
args (transformers.TrainingArguments) — The arguments to use for training.
data_collator (transformers.DataCollator) — The data collator to use for training. If None is specified, the default data collator (DPODataCollatorWithPadding) will be used, which pads the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
label_pad_token_id (int, defaults to -100) — The label pad token id. This argument is required if you want to use the default data collator.
padding_value (int, defaults to 0) — The padding value. This argument is required if you want to use the default data collator.
truncation_mode (str, defaults to keep_end) — The truncation mode to use, either keep_end or keep_start. This argument is required if you want to use the default data collator.
train_dataset (datasets.Dataset) — The dataset to use for training.
eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
tokenizer (transformers.PreTrainedTokenizerBase) — The tokenizer to use for training. This argument is required if you want to use the default data collator.
model_init (Callable[[], transformers.PreTrainedModel]) — The model initializer to use for training. If None is specified, the default model initializer will be used.
callbacks (List[transformers.TrainerCallback]) — The callbacks to use for training.
optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
max_length (int, defaults to None) — The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (int, defaults to None) — The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_target_length (int, defaults to None) — The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
peft_config (Dict, defaults to None) — The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
is_encoder_decoder (Optional[bool], optional, defaults to None) — If no model is provided, we need to know whether model_init returns an encoder-decoder.
disable_dropout (bool, defaults to True) — Whether or not to disable dropout in model and ref_model.
generate_during_eval (bool, defaults to False) — Whether to sample and log generations during the evaluation step.
compute_metrics (Callable[[EvalPrediction], Dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
Initialize DPOTrainer.
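The difference between the two truncation_mode values can be shown on a list of prompt token ids. This sketch simplifies by truncating only the prompt, once it exceeds max_prompt_length. truncate_prompt is a hypothetical helper, and the collator's exact handling (for example of the response) may differ.

```python
from typing import List


def truncate_prompt(
    prompt_ids: List[int], max_prompt_length: int, truncation_mode: str = "keep_end"
) -> List[int]:
    """Shorten a prompt that exceeds max_prompt_length (hypothetical helper).

    keep_start keeps the first tokens; keep_end keeps the last tokens,
    i.e. the part of the prompt closest to the response.
    """
    if max_prompt_length <= 0:
        raise ValueError("max_prompt_length must be positive")
    if len(prompt_ids) <= max_prompt_length:
        return list(prompt_ids)
    if truncation_mode == "keep_start":
        return prompt_ids[:max_prompt_length]
    if truncation_mode == "keep_end":
        return prompt_ids[-max_prompt_length:]
    raise ValueError(f"Unknown truncation_mode: {truncation_mode}")


prompt = [101, 7, 8, 9, 10, 11]
print(truncate_prompt(prompt, 3, "keep_start"))  # [101, 7, 8]
print(truncate_prompt(prompt, 3, "keep_end"))    # [9, 10, 11]
```

keep_end is the default in the signature above. It preserves the most recent context, which sits right before the response.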
concatenated_forward
( model: Module, batch: typing.Dict[str, typing.Union[typing.List, torch.LongTensor]] )
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it’s faster for FSDP.
concatenated_inputs
( batch: typing.Dict[str, typing.Union[typing.List, torch.LongTensor]] )
Concatenate the chosen and rejected inputs into a single tensor.
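Here is the idea behind concatenated_inputs and concatenated_forward, using plain lists of token ids:
1. Pad the chosen and rejected sequences to a common length.
2. Stack them, with the chosen rows first.
3. Run one forward pass.
4. Split the per-row results back at the batch size.

concatenate_pairs and split_outputs are hypothetical names. A single pad_id stands in for the separate label_pad_token_id and padding_value the trainer accepts.

```python
from typing import List, Sequence, Tuple


def concatenate_pairs(
    chosen: List[List[int]], rejected: List[List[int]], pad_id: int = 0
) -> List[List[int]]:
    """Right-pad every chosen and rejected sequence to one common length and
    stack them, chosen rows first, so one forward pass covers both halves."""
    if len(chosen) != len(rejected):
        raise ValueError("chosen and rejected must have the same batch size")
    rows = chosen + rejected
    max_len = max(len(seq) for seq in rows)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in rows]


def split_outputs(outputs: Sequence, batch_size: int) -> Tuple[list, list]:
    """Split per-row results of the concatenated batch back into the chosen
    half (first batch_size rows) and the rejected half."""
    return list(outputs[:batch_size]), list(outputs[batch_size:])


batch = concatenate_pairs([[1, 2, 3], [4, 5]], [[6], [7, 8, 9, 10]])
print(batch)  # [[1, 2, 3, 0], [4, 5, 0, 0], [6, 0, 0, 0], [7, 8, 9, 10]]
```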
dpo_loss
( policy_chosen_logps: FloatTensor, policy_rejected_logps: FloatTensor, reference_chosen_logps: FloatTensor, reference_rejected_logps: FloatTensor, reference_free: bool = False ) → A tuple of three tensors
Returns
A tuple of three tensors (losses, chosen_rewards, rejected_rewards). The losses tensor contains the DPO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
Compute the DPO loss for a batch of policy and reference model log probabilities.
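Below is a scalar sketch of this computation, using the math module instead of tensors. For each example it takes the policy log-ratio between chosen and rejected responses and subtracts the same log-ratio under the reference model. That reference term is assumed to be zero when reference_free is set. The difference is scaled by beta and passed through -log sigmoid. The per-example loop and the log_sigmoid helper are illustrative assumptions, not TRL's implementation.

```python
import math
from typing import List, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logps: List[float],
    policy_rejected_logps: List[float],
    reference_chosen_logps: List[float],
    reference_rejected_logps: List[float],
    beta: float = 0.1,
    reference_free: bool = False,
) -> Tuple[List[float], List[float], List[float]]:
    """Per-example DPO loss and implicit rewards from sequence log-probabilities."""
    n = len(policy_chosen_logps)
    others = (policy_rejected_logps, reference_chosen_logps, reference_rejected_logps)
    if any(len(v) != n for v in others):
        raise ValueError("all inputs must have the same length")
    losses, chosen_rewards, rejected_rewards = [], [], []
    for pc, pr, rc, rr in zip(policy_chosen_logps, *others):
        # How much more the policy prefers chosen over rejected...
        pi_logratio = pc - pr
        # ...relative to the reference model (treated as 0 in reference-free mode).
        ref_logratio = 0.0 if reference_free else rc - rr
        losses.append(-log_sigmoid(beta * (pi_logratio - ref_logratio)))
        # Implicit rewards: beta-scaled policy/reference log-ratios.
        chosen_rewards.append(beta * (pc - rc))
        rejected_rewards.append(beta * (pr - rr))
    return losses, chosen_rewards, rejected_rewards
```

When the policy equals the reference, every loss is log 2 (about 0.693). Training lowers it by raising the chosen log-ratio relative to the rejected one.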
evaluation_loop
( dataloader: DataLoader, description: str, prediction_loss_only: typing.Optional[bool] = None, ignore_keys: typing.Optional[typing.List[str]] = None, metric_key_prefix: str = 'eval' )
Overrides the built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with and without labels.
get_batch_metrics
( model, batch: typing.Dict[str, typing.Union[typing.List, torch.LongTensor]], train_eval: typing.Literal['train', 'eval'] = 'train' )
Compute the DPO loss and other metrics for the given batch of inputs for train or test.
get_batch_samples
( model, batch: typing.Dict[str, torch.LongTensor] )
Generate samples from the model and reference model for the given batch of inputs.
log
( logs: typing.Dict[str, float] )
Parameters
logs (Dict[str, float]) — The values to log.
Log the values in logs on the various objects watching training, including stored metrics.
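A small accumulator can illustrate "including stored metrics":
1. Per-batch metrics are appended as they are computed.
2. When log is called, they are averaged and merged into logs.
3. The stored values are then cleared.

MetricStore is hypothetical. Its rule for telling training logs from evaluation logs (a "loss" key means training) is an assumption.

```python
from collections import defaultdict
from statistics import mean
from typing import DefaultDict, Dict, List


class MetricStore:
    """Hypothetical helper: accumulate per-batch metrics, average them at log time."""

    def __init__(self) -> None:
        self._stored: Dict[str, DefaultDict[str, List[float]]] = {
            "train": defaultdict(list),
            "eval": defaultdict(list),
        }

    def store(self, metrics: Dict[str, float], train_eval: str = "train") -> None:
        # Append this batch's values; nothing is averaged yet.
        for key, value in metrics.items():
            self._stored[train_eval][key].append(value)

    def log(self, logs: Dict[str, float]) -> Dict[str, float]:
        # Assumption: training logs carry "loss", evaluation logs carry "eval_loss".
        train_eval = "train" if "loss" in logs else "eval"
        merged = dict(logs)
        for key, values in self._stored[train_eval].items():
            merged[key] = mean(values)  # average over the batches since the last log
        self._stored[train_eval].clear()
        return merged


store = MetricStore()
store.store({"rewards/chosen": 1.0})
store.store({"rewards/chosen": 3.0})
print(store.log({"loss": 0.5}))  # {'loss': 0.5, 'rewards/chosen': 2.0}
```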