# Quickstart

🌍 Optimum Neuron was designed with one goal in mind: **to make training and inference straightforward for any** 🌍 **Transformers user while leveraging the complete power of AWS Accelerators**.

### Training

There are two main classes one needs to know:

* `NeuronArgumentParser`: inherits from the original [HfArgumentParser](https://huggingface.co/docs/transformers/main/en/internal/trainer_utils#transformers.HfArgumentParser) in Transformers, with additional checks on the argument values to make sure they will work well with AWS Trainium instances.
* [NeuronTrainer](https://huggingface.co/docs/optimum/neuron/package_reference/trainer): the trainer class that takes care of compiling and distributing the model to run on Trainium Chips, and performing training and evaluation.

The [NeuronTrainer](https://huggingface.co/docs/optimum/neuron/package_reference/trainer) is very similar to the 🌍 [Transformers Trainer](https://huggingface.co/docs/transformers/main_classes/trainer): adapting a script that uses the `Trainer` to work on Trainium mostly consists of swapping the `Trainer` class for the `NeuronTrainer` one. That’s how most of the [example scripts](https://github.com/huggingface/optimum-neuron/tree/main/examples) were adapted from their [original counterparts](https://github.com/huggingface/transformers/tree/main/examples/pytorch).

The required modifications are minimal:

```diff
from transformers import TrainingArguments
-from transformers import Trainer
+from optimum.neuron import NeuronTrainer as Trainer
training_args = TrainingArguments(
  # training arguments...
)
# A lot of code here
# Initialize our Trainer
trainer = Trainer(
    model=model,
    args=training_args,  # Original training arguments.
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
```

All Trainium instances come with at least 2 Neuron Cores. To leverage them, we need to launch the training with `torchrun`. Below is an example of how to launch a training script on a `trn1.2xlarge` instance using a `bert-base-uncased` model.

```bash
torchrun --nproc_per_node=2 huggingface-neuron-samples/text-classification/run_glue.py \
--model_name_or_path bert-base-uncased \
--dataset_name philschmid/emotion \
--do_train \
--do_eval \
--bf16 True \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir ./bert-emotion
```
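
With `--nproc_per_node=2`, `torchrun` launches one worker per Neuron Core, and each worker processes its own per-device batch, so the effective global batch size is the per-device batch size multiplied by the number of workers. A quick sanity check (plain Python, numbers taken from the command above):

```python
# torchrun launches one training worker per Neuron Core.
nproc_per_node = 2                 # workers on a trn1.2xlarge (2 Neuron Cores)
per_device_train_batch_size = 16   # value passed to the training script

# The global batch size scales with the number of workers.
global_batch_size = per_device_train_batch_size * nproc_per_node
print(global_batch_size)  # 32
```

Keep this in mind when tuning the learning rate: moving to an instance with more Neuron Cores changes the global batch size even if the per-device value stays the same.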

### Inference

You can compile and export your 🌍 Transformers models to a serialized format before inference on Neuron devices:

```bash
optimum-cli export neuron \
  --model distilbert-base-uncased-finetuned-sst-2-english \
  --batch_size 1 \
  --sequence_length 32 \
  --auto_cast matmul \
  --auto_cast_type bf16 \
  distilbert_base_uncased_finetuned_sst2_english_neuron/
```

The command above will export `distilbert-base-uncased-finetuned-sst-2-english` with static shapes: `batch_size=1` and `sequence_length=32`, and cast all `matmul` operations from FP32 to BF16. Check out the [exporter guide](https://huggingface.co/docs/optimum-neuron/guides/export_model#exporting-a-model-to-neuron-using-the-cli) for more compilation options.
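
Static shapes mean that every input fed to the compiled model must match the exported dimensions exactly: shorter sequences get padded and longer ones truncated to `sequence_length=32`. The constraint itself can be sketched in plain Python; `pad_id=0` here is an illustrative placeholder, as real tokenizers define their own padding token id:

```python
def fit_to_static_shape(token_ids, seq_len=32, pad_id=0):
    """Pad or truncate a list of token ids to the compiled sequence length.

    pad_id=0 is a placeholder; a real tokenizer supplies its own pad token id.
    """
    if len(token_ids) >= seq_len:
        return token_ids[:seq_len]                            # truncate long inputs
    return token_ids + [pad_id] * (seq_len - len(token_ids))  # right-pad short ones

print(len(fit_to_static_shape([101, 2023, 102])))  # 32 (padded)
print(len(fit_to_static_shape(list(range(50)))))   # 32 (truncated)
```

In practice you do not implement this yourself: configuring the tokenizer with padding and truncation to the exported `sequence_length` achieves the same result.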

Then you can run the exported Neuron model on Neuron devices with the `NeuronModelForXXX` classes, which are similar to the `AutoModelForXXX` classes in 🌍 Transformers:

```diff
from transformers import AutoTokenizer
-from transformers import AutoModelForSequenceClassification
+from optimum.neuron import NeuronModelForSequenceClassification

# PyTorch checkpoint
-model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+model = NeuronModelForSequenceClassification.from_pretrained("distilbert_base_uncased_finetuned_sst2_english_neuron")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
inputs = tokenizer("Hamilton is considered to be the best musical of past years.", return_tensors="pt")

logits = model(**inputs).logits
print(model.config.id2label[logits.argmax().item()])
# 'POSITIVE'
```
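
The `argmax` above only returns the most likely class. If you also want confidence scores, applying a plain softmax to the logits is enough; the logit values below are made up for illustration:

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the two classes ["NEGATIVE", "POSITIVE"].
probs = softmax([-1.2, 3.4])
print(probs.index(max(probs)))  # 1 -> maps to "POSITIVE" via id2label
```

In a real script you would apply this to `logits.tolist()[0]` from the model output, then look the index up in `model.config.id2label` as shown above.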
