Neuron Trainer
NeuronTrainer
The NeuronTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.
The NeuronTrainer class is optimized for 🤗 Transformers models running on AWS Trainium.
Here is an example of how to customize NeuronTrainer to use a weighted loss (useful when you have an imbalanced training set):
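A common way to do this is to subclass NeuronTrainer, override its compute_loss method, and build the loss with torch.nn.CrossEntropyLoss(weight=...). This assumes that NeuronTrainer exposes the same compute_loss hook as the Transformers Trainer it extends. The plain-Python sketch below is not the library's code. It only illustrates what such a weighted loss computes, and the function name weighted_cross_entropy and the example weights are hypothetical. Each sample's negative log-likelihood is scaled by the weight of its true class, and the sum is divided by the total weight, as in the "mean" reduction of the PyTorch loss.

```python
import math
from typing import Sequence


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    class_weights: Sequence[float],
) -> float:
    """Hypothetical helper: weighted mean cross-entropy over a batch.

    Each sample's loss is multiplied by the weight of its true class, and the
    total is divided by the sum of those weights (not by the batch size).
    """
    total_loss = 0.0
    total_weight = 0.0
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp: subtract the max logit first.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        nll = log_sum_exp - row[label]  # equals -log softmax(row)[label]
        w = class_weights[label]
        total_loss += w * nll
        total_weight += w
    return total_loss / total_weight


# Example: two classes, where class 1 is the rare class we want to upweight.
logits = [[2.0, 0.0], [2.0, 0.0]]
labels = [0, 1]  # the second sample (class 1) is misclassified
plain = weighted_cross_entropy(logits, labels, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 3.0])
print(f"unweighted={plain:.4f} weighted={weighted:.4f}")
# prints: unweighted=1.1269 weighted=1.6269
```

With weights [1.0, 3.0], a mistake on the rarer class 1 counts three times as much as one on class 0, so the loss rises more when the minority class is misclassified.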
Another way to customize the training loop behavior of the PyTorch NeuronTrainer is to use callbacks, which can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML platforms, …) and make decisions (like early stopping).
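Transformers ships an EarlyStoppingCallback, configured with early_stopping_patience and early_stopping_threshold, that can be passed through the callbacks argument. This assumes NeuronTrainer forwards that argument to the Transformers Trainer. The sketch below is a simplified, hypothetical stand-in for the decision such a callback makes after each evaluation, not the library's implementation, and PatienceEarlyStopper is not a library class. Training stops once the monitored metric (here an eval loss, where lower is better) has failed to improve by more than a threshold for a given number of evaluations in a row.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PatienceEarlyStopper:
    """Hypothetical patience-based early-stopping rule (lower metric is better)."""

    patience: int = 3
    threshold: float = 0.0
    best: Optional[float] = None
    bad_evals: int = 0

    def should_stop(self, metric: float) -> bool:
        # An evaluation counts as an improvement only if it beats the best
        # value seen so far by more than `threshold`.
        if self.best is None or metric < self.best - self.threshold:
            self.best = metric
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        # Stop after `patience` consecutive evaluations without improvement.
        return self.bad_evals >= self.patience


stopper = PatienceEarlyStopper(patience=2)
for step, eval_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.should_stop(eval_loss):
        print(f"stopping after evaluation {step}")  # triggers at step 3
        break
```

The threshold keeps tiny, noise-level improvements from resetting the patience counter.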
NeuronTrainer
class optimum.neuron.NeuronTrainer(*args, **kwargs)
Trainer that is suited for performing training on AWS Trainium instances.
class optimum.neuron.Seq2SeqNeuronTrainer(*args, **kwargs)
Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.