Neuron Trainer

The NeuronTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.

The NeuronTrainer class is optimized for 🤗 Transformers models running on AWS Trainium.

Here is an example of how to customize NeuronTrainer to use a weighted loss (useful when you have an unbalanced training set):

import torch
from torch import nn
from optimum.neuron import NeuronTrainer


class CustomNeuronTrainer(NeuronTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor is created on the same device as the logits
        class_weights = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
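
To see what the class weights do numerically, the following standard-library sketch reproduces the arithmetic of a weighted cross-entropy with mean reduction: each example's loss is multiplied by the weight of its true class, and the sum is divided by the sum of those weights. The function name `weighted_cross_entropy` and the toy logits are illustrative assumptions, not part of optimum.neuron or PyTorch.

```python
import math


def weighted_cross_entropy(logits, labels, class_weights):
    """Weighted mean of per-example cross-entropy losses (hypothetical helper).

    logits: list of per-example score lists, labels: list of int class ids,
    class_weights: one weight per class.
    """
    total, weight_sum = 0.0, 0.0
    for scores, label in zip(logits, labels):
        # log-sum-exp with max subtraction for numerical stability
        m = max(scores)
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_norm - scores[label]  # negative log-likelihood of the true class
        w = class_weights[label]
        total += w * nll
        weight_sum += w
    return total / weight_sum


# Two examples with identical logits that favor class 0:
# the first is labeled 0 (well predicted), the second is labeled 2 (badly predicted).
logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
labels = [0, 2]

unweighted = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
```

With the weights `[1.0, 2.0, 3.0]`, the mistake on the class-2 example counts three times as much as the class-0 example, so `weighted` comes out larger than `unweighted`. This is the effect that makes a weighted loss useful when rare classes are assigned higher weights.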

Another way to customize the training loop behavior for the PyTorch NeuronTrainer is to use callbacks that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms…) and make decisions (like early stopping).
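
As a minimal sketch of such a callback, the class below records every logged loss and asks the loop to stop once the loss falls below a threshold. It builds on the `TrainerCallback` base class from Transformers. The class name `LossThresholdCallback` and the threshold value are hypothetical, and passing it to NeuronTrainer assumes that NeuronTrainer accepts the same `callbacks` argument as the Transformers Trainer.

```python
from transformers import TrainerCallback, TrainerControl, TrainerState


class LossThresholdCallback(TrainerCallback):
    """Hypothetical callback: track logged losses and request a stop below a threshold."""

    def __init__(self, threshold=0.1):
        self.threshold = threshold
        self.history = []  # (global_step, loss) pairs seen so far

    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss")
        if loss is None:
            # Not every log entry carries a training loss (e.g. evaluation logs)
            return
        self.history.append((state.global_step, loss))
        if loss < self.threshold:
            # Signal the training loop that it should stop
            control.should_training_stop = True


# Exercising the callback by hand, without running a real training loop
callback = LossThresholdCallback(threshold=0.5)
state, control = TrainerState(), TrainerControl()

state.global_step = 10
callback.on_log(None, state, control, logs={"loss": 0.9})  # above threshold

state.global_step = 20
callback.on_log(None, state, control, logs={"loss": 0.3})  # below threshold
```

In a real run, an instance would typically be supplied at construction time, for example `callbacks=[LossThresholdCallback(threshold=0.05)]`, under the assumption stated above.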

NeuronTrainer

class optimum.neuron.NeuronTrainer

( *args, **kwargs )

Trainer that is suited for performing training on AWS Trainium instances.

class optimum.neuron.Seq2SeqNeuronTrainer

( *args, **kwargs )

Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.
