Neuron Trainer
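NeuronTrainer

To customize the training loss, subclass `NeuronTrainer` and override its `compute_loss` method. The example below uses a cross-entropy loss that gives each of three labels a different weight: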
import torch
from torch import nn
from optimum.neuron import NeuronTrainer
class CustomNeuronTrainer(NeuronTrainer):
    """NeuronTrainer whose loss is a class-weighted cross-entropy.

    Example of overriding ``compute_loss`` to plug in a custom loss. The
    default ``class_weights`` assume a 3-label classification task; override
    the attribute so it has one entry per label (``model.config.num_labels``).
    """

    # One weight per label, in label-id order (3 labels with different weights).
    class_weights = (1.0, 2.0, 3.0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass and return the weighted cross-entropy loss.

        Args:
            model: The model to run; it is called as ``model(**inputs)``.
            inputs: Batch mapping; must contain a non-None ``"labels"`` entry.
            return_outputs: If true, also return the model outputs.
            num_items_in_batch: Accepted and ignored, so this override stays
                call-compatible with trainer versions that pass it.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.

        Raises:
            ValueError: If ``inputs`` has no ``"labels"`` entry.
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("compute_loss requires a 'labels' entry in inputs")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Create the weights on the logits' device and dtype: a default CPU
        # float32 tensor would mismatch logits that live on an accelerator
        # (e.g. an XLA/Neuron device) or are computed in reduced precision.
        weight = torch.tensor(self.class_weights, device=logits.device, dtype=logits.dtype)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
class optimum.neuron.NeuronTrainer
class optimum.neuron.Seq2SeqNeuronTrainer
Last updated