Reward Model Training

Reward Modeling

TRL supports custom reward modeling, so anyone can train a reward model on their own dataset and model.

Check out a complete, flexible example in the examples/scripts folder.

Expected dataset format

The RewardTrainer expects a very specific format for the dataset, because the model is trained on pairs of examples to predict which of the two is preferred. The Anthropic/hh-rlhf dataset is one example: each row pairs a chosen response with a rejected one.

Therefore, the final dataset object should contain at least 4 entries if you use the default RewardDataCollatorWithPadding data collator. The entries should be named:

  • input_ids_chosen

  • attention_mask_chosen

  • input_ids_rejected

  • attention_mask_rejected
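
As a minimal sketch of how a dataset with raw chosen and rejected text columns (the column names used by Anthropic/hh-rlhf) could be turned into these four entries, the function below tokenizes each pair. The toy_tokenizer is a hypothetical stand-in so the snippet runs without any downloads; in practice you would pass a real tokenizer, for instance one loaded with AutoTokenizer.from_pretrained, and apply the function with dataset.map(preprocess_function, batched=True).

```python
def toy_tokenizer(text):
    """Hypothetical stand-in for a real tokenizer: one fake id per whitespace token."""
    tokens = text.split()
    return {
        "input_ids": [len(tok) for tok in tokens],  # fake ids, for illustration only
        "attention_mask": [1] * len(tokens),
    }


def preprocess_function(examples, tokenizer=toy_tokenizer):
    """Tokenize a batch of chosen/rejected pairs into the four expected columns."""
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples


# A batch in the "dict of lists" layout that a batched map would pass in.
batch = {
    "chosen": ["Human: Hi Assistant: Hello, how can I help?"],
    "rejected": ["Human: Hi Assistant: Go away."],
}
processed = preprocess_function(batch)
print(sorted(processed))
```

Each output column holds one list per pair, so the chosen and rejected sequences stay aligned by row index.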

Using the RewardTrainer

After preparing your dataset, you can use the RewardTrainer in the same way as the Trainer class from 🤗 Transformers. You should pass an AutoModelForSequenceClassification model to the RewardTrainer, along with a RewardConfig that configures the training hyperparameters.

Leveraging 🤗 PEFT to train a reward model

Just pass a peft_config in the keyword arguments of RewardTrainer, and the trainer should automatically take care of converting the model into a PEFT model!


from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig

model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

...

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
    peft_config=peft_config,
)

trainer.train()

Adding a margin to the loss

As in the Llama 2 paper, you can add a margin to the loss by adding a margin column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.


def add_margin(row):
    # Assumes the dataset has score_chosen and score_rejected columns from which to compute the margin
    return {'margin': row['score_chosen'] - row['score_rejected']}

dataset = dataset.map(add_margin)
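
To make the effect of the margin concrete, here is a small sketch of a pairwise ranking loss in plain Python. It assumes the margin is subtracted from the reward difference inside the log-sigmoid, following the ranking loss described in the Llama 2 paper. The function names are hypothetical and this is an illustration, not the trainer's actual code.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_loss(reward_chosen, reward_rejected, margin=0.0):
    """Ranking loss for one pair: -log(sigmoid(r_chosen - r_rejected - margin))."""
    return -log_sigmoid(reward_chosen - reward_rejected - margin)


# Same rewards, with and without a margin of 1.0.
no_margin = pairwise_loss(2.0, 1.0)
with_margin = pairwise_loss(2.0, 1.0, margin=1.0)
print(f"no margin: {no_margin:.4f}, with margin: {with_margin:.4f}")
```

With a positive margin, the same reward gap yields a larger loss, so the model is pushed to separate the chosen and rejected responses by at least that margin.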

RewardConfig

class trl.RewardConfig


(
    output_dir: str,
    overwrite_output_dir: bool = False,
    do_train: bool = False,
    do_eval: bool = False,
    do_predict: bool = False,
    evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no',
    prediction_loss_only: bool = False,
    per_device_train_batch_size: int = 8,
    per_device_eval_batch_size: int = 8,
    per_gpu_train_batch_size: typing.Optional[int] = None,
    per_gpu_eval_batch_size: typing.Optional[int] = None,
    gradient_accumulation_steps: int = 1,
    eval_accumulation_steps: typing.Optional[int] = None,
    eval_delay: typing.Optional[float] = 0,
    learning_rate: float = 5e-05,
    weight_decay: float = 0.0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    num_train_epochs: float = 3.0,
    max_steps: int = -1,
    lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear',
    warmup_ratio: float = 0.0,
    warmup_steps: int = 0,
    log_level: typing.Optional[str] = 'passive',
    log_level_replica: typing.Optional[str] = 'warning',
    log_on_each_node: bool = True,
    logging_dir: typing.Optional[str] = None,
    logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps',
    logging_first_step: bool = False,
    logging_steps: float = 500,
    logging_nan_inf_filter: bool = True,
    save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps',
    save_steps: float = 500,
    save_total_limit: typing.Optional[int] = None,
    save_safetensors: typing.Optional[bool] = False,
    save_on_each_node: bool = False,
    no_cuda: bool = False,
    use_cpu: bool = False,
    use_mps_device: bool = False,
    seed: int = 42,
    data_seed: typing.Optional[int] = None,
    jit_mode_eval: bool = False,
    use_ipex: bool = False,
    bf16: bool = False,
    fp16: bool = False,
    fp16_opt_level: str = 'O1',
    half_precision_backend: str = 'auto',
    bf16_full_eval: bool = False,
    fp16_full_eval: bool = False,
    tf32: typing.Optional[bool] = None,
    local_rank: int = -1,
    ddp_backend: typing.Optional[str] = None,
    tpu_num_cores: typing.Optional[int] = None,
    tpu_metrics_debug: bool = False,
    debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '',
    dataloader_drop_last: bool = False,
    eval_steps: typing.Optional[float] = None,
    dataloader_num_workers: int = 0,
    past_index: int = -1,
    run_name: typing.Optional[str] = None,
    disable_tqdm: typing.Optional[bool] = None,
    remove_unused_columns: typing.Optional[bool] = True,
    label_names: typing.Optional[typing.List[str]] = None,
    load_best_model_at_end: typing.Optional[bool] = False,
    metric_for_best_model: typing.Optional[str] = None,
    greater_is_better: typing.Optional[bool] = None,
    ignore_data_skip: bool = False,
    fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '',
    fsdp_min_num_params: int = 0,
    fsdp_config: typing.Optional[str] = None,
    fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None,
    deepspeed: typing.Optional[str] = None,
    label_smoothing_factor: float = 0.0,
    optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch',
    optim_args: typing.Optional[str] = None,
    adafactor: bool = False,
    group_by_length: bool = False,
    length_column_name: typing.Optional[str] = 'length',
    report_to: typing.Optional[typing.List[str]] = None,
    ddp_find_unused_parameters: typing.Optional[bool] = None,
    ddp_bucket_cap_mb: typing.Optional[int] = None,
    ddp_broadcast_buffers: typing.Optional[bool] = None,
    dataloader_pin_memory: bool = True,
    skip_memory_metrics: bool = True,
    use_legacy_prediction_loop: bool = False,
    push_to_hub: bool = False,
    resume_from_checkpoint: typing.Optional[str] = None,
    hub_model_id: typing.Optional[str] = None,
    hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save',
    hub_token: typing.Optional[str] = None,
    hub_private_repo: bool = False,
    hub_always_push: bool = False,
    gradient_checkpointing: typing.Optional[bool] = True,
    include_inputs_for_metrics: bool = False,
    fp16_backend: str = 'auto',
    push_to_hub_model_id: typing.Optional[str] = None,
    push_to_hub_organization: typing.Optional[str] = None,
    push_to_hub_token: typing.Optional[str] = None,
    mp_parameters: str = '',
    auto_find_batch_size: bool = False,
    full_determinism: bool = False,
    torchdynamo: typing.Optional[str] = None,
    ray_scope: typing.Optional[str] = 'last',
    ddp_timeout: typing.Optional[int] = 1800,
    torch_compile: bool = False,
    torch_compile_backend: typing.Optional[str] = None,
    torch_compile_mode: typing.Optional[str] = None,
    dispatch_batches: typing.Optional[bool] = None,
    include_tokens_per_second: typing.Optional[bool] = False,
    max_length: typing.Optional[int] = None
)

Parameters

  • max_length (int, optional, defaults to None): The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.

  • gradient_checkpointing (bool, optional, defaults to True): If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

RewardConfig collects all training arguments related to the RewardTrainer class.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.

RewardTrainer

class trl.RewardTrainer


(
    model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None,
    args: typing.Optional[trl.trainer.training_configs.RewardConfig] = None,
    data_collator: typing.Optional[DataCollator] = None,
    train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None,
    eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, typing.Dict[str, datasets.arrow_dataset.Dataset], NoneType] = None,
    tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
    model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None,
    compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None,
    callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None,
    optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None,
    max_length: typing.Optional[int] = None,
    peft_config: typing.Optional[typing.Dict] = None
)

The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the transformers.Trainer class and inherits all of its attributes and methods. It is recommended to use an AutoModelForSequenceClassification as the reward model. The reward model should be trained on a dataset of paired examples, where each example is a tuple of two sequences. The reward model should be trained to predict which example in the pair is more relevant to the task at hand.

The reward trainer expects a very specific format for the dataset. The dataset should contain at least 4 entries if you use the default RewardDataCollatorWithPadding data collator. The entries should be named:

  • input_ids_chosen

  • attention_mask_chosen

  • input_ids_rejected

  • attention_mask_rejected

Optionally, you can also pass a margin entry to the dataset. This entry should contain the margin used to modulate the loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. If you don’t pass a margin, no margin will be used.
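
To illustrate why the entries come in chosen/rejected pairs, the sketch below batches a few examples by padding the chosen and rejected sequences separately and forwarding the optional margin. The function pad_pairs and its pad_token_id parameter are hypothetical; this is a simplified plain-Python illustration under those assumptions, not the implementation of RewardDataCollatorWithPadding.

```python
def pad_pairs(features, pad_token_id=0):
    """Pad chosen and rejected sequences separately to their own batch maximum."""
    batch = {}
    for side in ("chosen", "rejected"):
        ids = [f[f"input_ids_{side}"] for f in features]
        masks = [f[f"attention_mask_{side}"] for f in features]
        max_len = max(len(seq) for seq in ids)
        # Right-pad ids with the pad token and masks with 0 (ignored positions).
        batch[f"input_ids_{side}"] = [
            seq + [pad_token_id] * (max_len - len(seq)) for seq in ids
        ]
        batch[f"attention_mask_{side}"] = [
            mask + [0] * (max_len - len(mask)) for mask in masks
        ]
    # Forward the optional margin only when every example provides one.
    if all("margin" in f for f in features):
        batch["margin"] = [f["margin"] for f in features]
    return batch


features = [
    {
        "input_ids_chosen": [5, 6, 7],
        "attention_mask_chosen": [1, 1, 1],
        "input_ids_rejected": [5, 6],
        "attention_mask_rejected": [1, 1],
        "margin": 0.5,
    },
    {
        "input_ids_chosen": [8],
        "attention_mask_chosen": [1],
        "input_ids_rejected": [8, 9, 10, 11],
        "attention_mask_rejected": [1, 1, 1, 1],
        "margin": 1.0,
    },
]
batch = pad_pairs(features)
print(batch["input_ids_chosen"], batch["input_ids_rejected"])
```

Padding each side independently keeps the chosen and rejected batches rectangular without forcing both to the length of the longest sequence overall.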
