AWS Trainium & Inferentia

Neuron Trainer


Last updated 1 year ago

NeuronTrainer

The NeuronTrainer class provides an extended API for the feature-complete Transformers Trainer. It is used in all the example scripts.

The NeuronTrainer class is optimized for 🤗 Transformers models running on AWS Trainium.

Here is an example of how to customize NeuronTrainer to use a weighted loss (useful when you have an unbalanced training set):


import torch
from torch import nn
from optimum.neuron import NeuronTrainer


class CustomNeuronTrainer(NeuronTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
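        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor is created on the same device as the logits
        class_weights = torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=class_weights)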
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
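
To make the effect of the class weights concrete, the arithmetic can be sketched in plain Python. This is an illustration only, not part of the library: the function name `weighted_cross_entropy` and the sample values are hypothetical, and the sketch assumes the mean-reduction behavior documented for PyTorch's `nn.CrossEntropyLoss`, where the weighted per-sample losses are divided by the sum of the target weights rather than by the batch size.

```python
import math


def weighted_cross_entropy(logits, labels, weights):
    """Mean-reduced, class-weighted cross-entropy over a batch (plain Python sketch)."""
    total_loss = 0.0
    total_weight = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the row maximum subtracted for numerical stability
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        # negative log-likelihood of the target class
        nll = log_norm - row[label]
        total_loss += weights[label] * nll
        total_weight += weights[label]
    # normalize by the summed weights of the targets, not by the batch size
    return total_loss / total_weight


# Hypothetical batch: the first sample is predicted well, the second is not.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]

uniform = weighted_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
weighted = weighted_cross_entropy(logits, labels, [1.0, 2.0, 3.0])
print(uniform, weighted)
```

With uniform weights the result is the ordinary mean loss. With weights `[1.0, 2.0, 3.0]`, the poorly predicted sample of class 2 counts three times as much as the sample of class 0, so the weighted loss is larger for this batch.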

NeuronTrainer

class optimum.neuron.NeuronTrainer

( *args, **kwargs )

Trainer that is suited for performing training on AWS Trainium instances.

class optimum.neuron.Seq2SeqNeuronTrainer

( *args, **kwargs )

Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.

Another way to customize the training loop behavior for the PyTorch NeuronTrainer is to use callbacks that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms…) and make decisions (like early stopping).
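
As a rough sketch of the callback mechanism, the class below stops training once the evaluation loss has stopped improving. It relies on the Transformers `TrainerCallback` interface (the `on_evaluate` hook and the `control.should_training_stop` flag). The class name `LossPlateauCallback`, the `patience` parameter, and the `SimpleNamespace` stand-in for the control object are hypothetical and used for illustration only. The `ImportError` fallback is there only so the sketch runs without Transformers installed.

```python
from types import SimpleNamespace

try:
    from transformers import TrainerCallback
except ImportError:  # fallback so the sketch runs without transformers installed
    TrainerCallback = object


class LossPlateauCallback(TrainerCallback):
    """Hypothetical callback: stop when eval_loss fails to improve `patience` times in a row."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best_loss = float("inf")
        self.stale_evals = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        loss = (metrics or {}).get("eval_loss")
        if loss is None:
            return control  # nothing to inspect for this evaluation
        if loss < self.best_loss:
            self.best_loss = loss
            self.stale_evals = 0
        else:
            self.stale_evals += 1
        if self.stale_evals >= self.patience:
            # ask the training loop to stop after the current step
            control.should_training_stop = True
        return control


# Simulated evaluation losses; `control` is a stand-in for TrainerControl.
control = SimpleNamespace(should_training_stop=False)
callback = LossPlateauCallback(patience=2)
for eval_loss in [0.9, 0.7, 0.8, 0.75]:
    callback.on_evaluate(args=None, state=None, control=control, metrics={"eval_loss": eval_loss})
print(control.should_training_stop)
```

In a real run, such a callback would presumably be passed as `callbacks=[LossPlateauCallback()]` when constructing the trainer, assuming NeuronTrainer forwards that argument to the Transformers Trainer unchanged. Transformers also ships a ready-made `EarlyStoppingCallback` that covers the common case.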
