Reward Model Training
Reward Modeling
TRL supports custom reward modeling, so anyone can train a reward model on their own dataset and model.
Check out a complete, flexible example inside the examples/scripts folder.
Expected dataset format
The RewardTrainer expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the Anthropic/hh-rlhf dataset below:

Therefore, the final dataset object should contain at least the following 4 entries if you use the default RewardDataCollatorWithPadding data collator. The entries should be named:
input_ids_chosen
attention_mask_chosen
input_ids_rejected
attention_mask_rejected
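As a minimal sketch of how one preference pair could be turned into these four entries, the helper below tokenizes a chosen and a rejected text and renames the outputs. The names build_reward_example and toy_tokenize are hypothetical and not part of TRL. The toy tokenizer only stands in for a real one so that the example runs without downloads. In practice, a Hugging Face tokenizer call, which also returns input_ids and attention_mask, would presumably take its place.

```python
from typing import Callable, Dict, List


def build_reward_example(
    chosen: str,
    rejected: str,
    tokenize: Callable[[str], Dict[str, List[int]]],
) -> Dict[str, List[int]]:
    """Tokenize one preference pair into the four entries the trainer expects."""
    tok_chosen = tokenize(chosen)
    tok_rejected = tokenize(rejected)
    return {
        "input_ids_chosen": tok_chosen["input_ids"],
        "attention_mask_chosen": tok_chosen["attention_mask"],
        "input_ids_rejected": tok_rejected["input_ids"],
        "attention_mask_rejected": tok_rejected["attention_mask"],
    }


def toy_tokenize(text: str) -> Dict[str, List[int]]:
    """Hypothetical stand-in tokenizer: one made-up id per whitespace-separated word."""
    ids = [sum(map(ord, word)) % 1000 for word in text.split()]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


example = build_reward_example(
    chosen="Human: hi Assistant: hello there",
    rejected="Human: hi Assistant: go away",
    tokenize=toy_tokenize,
)
print(sorted(example))
```

With a datasets.Dataset that has chosen and rejected text columns, a function of this shape could be applied row by row through dataset.map.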
Using the RewardTrainer
After preparing your dataset, you can use the RewardTrainer in the same way as the Trainer class from 🤗 Transformers. You should pass an AutoModelForSequenceClassification model to the RewardTrainer, along with a RewardConfig which configures the hyperparameters of the training.
Leveraging 🤗 PEFT to train a reward model
Just pass a peft_config in the keyword arguments of RewardTrainer, and the trainer should automatically take care of converting the model into a PEFT model!
from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig

model = AutoModelForSequenceClassification.from_pretrained("gpt2")

# LoRA configuration for a sequence-classification (reward) head
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

...

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
    peft_config=peft_config,
)

trainer.train()

Adding a margin to the loss
As in the Llama 2 paper, you can add a margin to the loss by adding a margin column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
def add_margin(row):
    # Assume you have score_chosen and score_rejected columns that you want to use to compute the margin
    return {'margin': row['score_chosen'] - row['score_rejected']}

dataset = dataset.map(add_margin)

RewardConfig
class trl.RewardConfig
(
    output_dir: str,
    overwrite_output_dir: bool = False,
    do_train: bool = False,
    do_eval: bool = False,
    do_predict: bool = False,
    evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no',
    prediction_loss_only: bool = False,
    per_device_train_batch_size: int = 8,
    per_device_eval_batch_size: int = 8,
    per_gpu_train_batch_size: typing.Optional[int] = None,
    per_gpu_eval_batch_size: typing.Optional[int] = None,
    gradient_accumulation_steps: int = 1,
    eval_accumulation_steps: typing.Optional[int] = None,
    eval_delay: typing.Optional[float] = 0,
    learning_rate: float = 5e-05,
    weight_decay: float = 0.0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    num_train_epochs: float = 3.0,
    max_steps: int = -1,
    lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear',
    warmup_ratio: float = 0.0,
    warmup_steps: int = 0,
    log_level: typing.Optional[str] = 'passive',
    log_level_replica: typing.Optional[str] = 'warning',
    log_on_each_node: bool = True,
    logging_dir: typing.Optional[str] = None,
    logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps',
    logging_first_step: bool = False,
    logging_steps: float = 500,
    logging_nan_inf_filter: bool = True,
    save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps',
    save_steps: float = 500,
    save_total_limit: typing.Optional[int] = None,
    save_safetensors: typing.Optional[bool] = False,
    save_on_each_node: bool = False,
    no_cuda: bool = False,
    use_cpu: bool = False,
    use_mps_device: bool = False,
    seed: int = 42,
    data_seed: typing.Optional[int] = None,
    jit_mode_eval: bool = False,
    use_ipex: bool = False,
    bf16: bool = False,
    fp16: bool = False,
    fp16_opt_level: str = 'O1',
    half_precision_backend: str = 'auto',
    bf16_full_eval: bool = False,
    fp16_full_eval: bool = False,
    tf32: typing.Optional[bool] = None,
    local_rank: int = -1,
    ddp_backend: typing.Optional[str] = None,
    tpu_num_cores: typing.Optional[int] = None,
    tpu_metrics_debug: bool = False,
    debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '',
    dataloader_drop_last: bool = False,
    eval_steps: typing.Optional[float] = None,
    dataloader_num_workers: int = 0,
    past_index: int = -1,
    run_name: typing.Optional[str] = None,
    disable_tqdm: typing.Optional[bool] = None,
    remove_unused_columns: typing.Optional[bool] = True,
    label_names: typing.Optional[typing.List[str]] = None,
    load_best_model_at_end: typing.Optional[bool] = False,
    metric_for_best_model: typing.Optional[str] = None,
    greater_is_better: typing.Optional[bool] = None,
    ignore_data_skip: bool = False,
    fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '',
    fsdp_min_num_params: int = 0,
    fsdp_config: typing.Optional[str] = None,
    fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None,
    deepspeed: typing.Optional[str] = None,
    label_smoothing_factor: float = 0.0,
    optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch',
    optim_args: typing.Optional[str] = None,
    adafactor: bool = False,
    group_by_length: bool = False,
    length_column_name: typing.Optional[str] = 'length',
    report_to: typing.Optional[typing.List[str]] = None,
    ddp_find_unused_parameters: typing.Optional[bool] = None,
    ddp_bucket_cap_mb: typing.Optional[int] = None,
    ddp_broadcast_buffers: typing.Optional[bool] = None,
    dataloader_pin_memory: bool = True,
    skip_memory_metrics: bool = True,
    use_legacy_prediction_loop: bool = False,
    push_to_hub: bool = False,
    resume_from_checkpoint: typing.Optional[str] = None,
    hub_model_id: typing.Optional[str] = None,
    hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save',
    hub_token: typing.Optional[str] = None,
    hub_private_repo: bool = False,
    hub_always_push: bool = False,
    gradient_checkpointing: typing.Optional[bool] = True,
    include_inputs_for_metrics: bool = False,
    fp16_backend: str = 'auto',
    push_to_hub_model_id: typing.Optional[str] = None,
    push_to_hub_organization: typing.Optional[str] = None,
    push_to_hub_token: typing.Optional[str] = None,
    mp_parameters: str = '',
    auto_find_batch_size: bool = False,
    full_determinism: bool = False,
    torchdynamo: typing.Optional[str] = None,
    ray_scope: typing.Optional[str] = 'last',
    ddp_timeout: typing.Optional[int] = 1800,
    torch_compile: bool = False,
    torch_compile_backend: typing.Optional[str] = None,
    torch_compile_mode: typing.Optional[str] = None,
    dispatch_batches: typing.Optional[bool] = None,
    include_tokens_per_second: typing.Optional[bool] = False,
    max_length: typing.Optional[int] = None
)
Parameters
max_length (int, optional, defaults to None): The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
gradient_checkpointing (bool, optional, defaults to True): If True, use gradient checkpointing to save memory at the expense of slower backward pass.
RewardConfig collects all training arguments related to the RewardTrainer class.
Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
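To illustrate the idea of turning a configuration class into command-line arguments, here is a rough sketch that uses only the standard library. It is not HfArgumentParser itself. ToyRewardConfig and parser_from_dataclass are hypothetical names, and the two fields merely borrow the names of RewardConfig parameters with made-up defaults.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class ToyRewardConfig:
    """Hypothetical stand-in with two fields named after RewardConfig parameters."""

    output_dir: str = "out"
    max_length: int = 512


def parser_from_dataclass(cls) -> argparse.ArgumentParser:
    """Create one --flag per dataclass field, reusing its type and default."""
    parser = argparse.ArgumentParser()
    for field in fields(cls):
        parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return parser


# Simulate a command line such as: python train.py --max_length 256
parsed = parser_from_dataclass(ToyRewardConfig).parse_args(["--max_length", "256"])
config = ToyRewardConfig(**vars(parsed))
print(config)
```

Every field becomes a flag of the same name, so any value left off the command line falls back to the default declared on the class.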
RewardTrainer
class trl.RewardTrainer
(
    model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,
    args: typing.Optional[trl.trainer.training_configs.RewardConfig] = None,
    data_collator: typing.Optional[DataCollator] = None,
    train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None,
    eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, typing.Dict[str, datasets.arrow_dataset.Dataset], NoneType] = None,
    tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
    model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None,
    compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None,
    callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None,
    optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None,
    max_length: typing.Optional[int] = None,
    peft_config: typing.Optional[typing.Dict] = None
)
The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the transformers.Trainer class and inherits all of its attributes and methods. It is recommended to use an AutoModelForSequenceClassification as the reward model. The reward model should be trained on a dataset of paired examples, where each example is a tuple of two sequences. The reward model should be trained to predict which example in the pair is more relevant to the task at hand.
The reward trainer expects a very specific format for the dataset. The dataset should contain at least the following 4 entries if you use the default RewardDataCollatorWithPadding data collator. The entries should be named:
input_ids_chosen
attention_mask_chosen
input_ids_rejected
attention_mask_rejected
Optionally, you can also pass a margin entry to the dataset. This entry should contain the margin used to modulate the loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. If you don’t pass a margin, no margin will be used.
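To make the effect of the margin concrete, the sketch below computes a pairwise ranking loss of the form -log(sigmoid(r_chosen - r_rejected - margin)), which is the margin formulation given in the Llama 2 paper. It is a plain-Python illustration over lists of scalar rewards, not the trainer's own code, and the function name pairwise_reward_loss is hypothetical.

```python
import math
from typing import Optional, Sequence


def pairwise_reward_loss(
    rewards_chosen: Sequence[float],
    rewards_rejected: Sequence[float],
    margins: Optional[Sequence[float]] = None,
) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected - margin)) over all pairs."""
    if margins is None:
        # No margin column: behave as if every margin were zero.
        margins = [0.0] * len(rewards_chosen)
    losses = []
    for r_c, r_r, m in zip(rewards_chosen, rewards_rejected, margins):
        z = r_c - r_r - m
        # -log(sigmoid(z)) equals log(1 + exp(-z)); this form avoids overflow.
        losses.append(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))
    return sum(losses) / len(losses)


no_margin = pairwise_reward_loss([2.0], [1.0])
with_margin = pairwise_reward_loss([2.0], [1.0], margins=[0.5])
print(no_margin, with_margin)
```

A positive margin raises the loss for the same pair of rewards, so the model is pushed to separate the chosen and rejected scores by at least that amount.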