Inference pipelines with AWS Neuron
The pipeline() function makes it simple to use models from the Model Hub for accelerated inference on a variety of tasks such as text classification and question answering.
You can also use the pipeline() function from Transformers and provide your NeuronModel model class.
Currently the supported tasks are:
feature-extraction
fill-mask
text-classification
token-classification
question-answering
zero-shot-classification
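Because only the tasks above are supported, it can help to fail early on anything else before exporting a model. The following is a minimal sketch that assumes you keep your own copy of this list. SUPPORTED_NEURON_TASKS and check_task are hypothetical names, not part of the Optimum Neuron API, and the hand-copied set can drift from the library as support changes.

```python
# Hypothetical constant: the tasks listed on this page, copied by hand.
# It is not read from optimum-neuron and may go out of date.
SUPPORTED_NEURON_TASKS = frozenset({
    "feature-extraction",
    "fill-mask",
    "text-classification",
    "token-classification",
    "question-answering",
    "zero-shot-classification",
})


def check_task(task: str) -> str:
    """Return the task unchanged if it is listed, otherwise raise ValueError."""
    if task not in SUPPORTED_NEURON_TASKS:
        supported = ", ".join(sorted(SUPPORTED_NEURON_TASKS))
        raise ValueError(f"Unsupported task {task!r}; expected one of: {supported}")
    return task
```

For example, check_task("fill-mask") returns "fill-mask", while check_task("image-classification") raises a ValueError.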
While each task has an associated pipeline class, it is simpler to use the general pipeline() function, which wraps all the task-specific pipelines in one object. The pipeline() function automatically loads a default model and a tokenizer or feature extractor capable of performing inference for your task.
Start by creating a pipeline by specifying an inference task:
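A minimal sketch of this step, assuming the optimum-neuron package is installed and the code runs on an AWS Inferentia or Trainium instance. It assumes the Neuron-aware pipeline() is importable from optimum.neuron.pipelines. The helper name create_classifier is hypothetical, and the import is deferred inside it so the module can be loaded on machines without the Neuron SDK.

```python
def create_classifier():
    """Build a text-classification pipeline that runs on AWS Neuron.

    Assumes optimum-neuron is installed on an Inferentia/Trainium host.
    The import is deferred so this file can be imported anywhere.
    """
    from optimum.neuron.pipelines import pipeline

    # Only the task is given, so the pipeline falls back to the
    # default model and tokenizer registered for that task.
    return pipeline(task="text-classification")
```

On a Neuron instance, classifier = create_classifier() returns a callable pipeline object. Building it can take a while, so create it once and reuse it.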
Pass your input text to the pipeline object returned by pipeline():
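The sketch below shows what calling the pipeline looks like and how its output can be consumed. It assumes optimum-neuron on a Neuron host for classify. format_predictions is a hypothetical pure-Python helper. It relies on the standard Transformers text-classification output: one dict per input with "label" and "score" keys.

```python
def classify(texts, classifier=None):
    """Run a Neuron text-classification pipeline on a list of strings.

    Assumes optimum-neuron is installed on an Inferentia/Trainium host.
    Pass an existing pipeline as classifier to avoid rebuilding it.
    """
    if classifier is None:
        # Deferred import so this module loads without the Neuron SDK.
        from optimum.neuron.pipelines import pipeline

        classifier = pipeline(task="text-classification")
    # Text-classification pipelines return one dict per input string,
    # each holding a "label" and a "score".
    return classifier(texts)


def format_predictions(texts, predictions):
    """Pair each input with its predicted label and score (pure Python)."""
    return [
        f"{pred['label']} ({pred['score']:.3f}): {text}"
        for text, pred in zip(texts, predictions)
    ]
```

On a Neuron instance, format_predictions(texts, classify(texts)) yields one line per input. The labels and scores depend on the default model.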
Note: The default models used in the pipeline() function are not optimized for inference or quantized, so there won't be a performance improvement compared to their PyTorch counterparts.
To load a model with the Neuron Runtime, export to the Neuron format must be supported for the model's architecture.
Once you have picked an appropriate model, you can create the pipeline by specifying the model repo:
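A minimal sketch of creating the pipeline from a model repo. It assumes the Neuron-aware pipeline() accepts export=True to convert a PyTorch checkpoint on the fly, mirroring the from_pretrained(..., export=True) option described next. The checkpoint deepset/roberta-base-squad2 is only an example choice, and answer_question is a hypothetical helper.

```python
def answer_question(question, context, model_id="deepset/roberta-base-squad2"):
    """Answer a question with a Hub checkpoint exported to Neuron on the fly.

    Assumes optimum-neuron is installed on an Inferentia/Trainium host.
    """
    from optimum.neuron.pipelines import pipeline

    # export=True (assumed here) converts the PyTorch checkpoint to the
    # Neuron format when the pipeline is created.
    qa = pipeline("question-answering", model=model_id, export=True)
    # Question-answering pipelines return a dict with "answer", "score",
    # "start" and "end" keys.
    return qa(question=question, context=context)
```

For example, answer_question("What's my name?", "My name is Philipp and I live in Nuremberg.") returns a dict whose "answer" entry holds the extracted span.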
It is also possible to load it with the from_pretrained(model_name_or_path, export=True) method associated with the NeuronModelForXXX class.
For example, here is how you can load the NeuronModelForQuestionAnswering class for question answering:
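A minimal sketch of this approach, assuming optimum-neuron and transformers are installed on a Neuron host. The model is exported explicitly with NeuronModelForQuestionAnswering.from_pretrained(..., export=True) and then handed to pipeline() together with its tokenizer. The helper name load_qa_pipeline and the example checkpoint are illustrative choices.

```python
def load_qa_pipeline(model_id="deepset/roberta-base-squad2"):
    """Export a checkpoint with NeuronModelForQuestionAnswering, then wrap it.

    Assumes optimum-neuron is installed on an Inferentia/Trainium host.
    """
    from transformers import AutoTokenizer
    from optimum.neuron import NeuronModelForQuestionAnswering
    from optimum.neuron.pipelines import pipeline

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # export=True loads the PyTorch checkpoint and converts it to the
    # Neuron format; no input shapes are given, so defaults are used.
    model = NeuronModelForQuestionAnswering.from_pretrained(model_id, export=True)
    return pipeline("question-answering", model=model, tokenizer=tokenizer)
```

Loading the model yourself keeps the exported model object available for reuse. A call then looks like qa = load_qa_pipeline() followed by qa(question="What's my name?", context="My name is Philipp and I live in Nuremberg.").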
NeuronModels currently require static input_shapes to run inference. If you do not provide input shapes when passing the export=True parameter, the default input shapes are used. The input shapes for the batch size and sequence length can also be specified explicitly:
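A minimal sketch of passing static shapes at export time. It assumes the Neuron-aware pipeline() accepts an input_shapes dict with "batch_size" and "sequence_length" keys alongside export=True. The checkpoint dslim/bert-base-NER is only an example. make_input_shapes and create_ner_pipeline are hypothetical helpers, and the positivity check reflects the fact that a static shape must be a concrete positive size.

```python
def make_input_shapes(batch_size=1, sequence_length=64):
    """Build the static input-shape dict; both values must be positive ints."""
    for name, value in (("batch_size", batch_size), ("sequence_length", sequence_length)):
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return {"batch_size": batch_size, "sequence_length": sequence_length}


def create_ner_pipeline(input_shapes=None, model_id="dslim/bert-base-NER"):
    """Export a token-classification model to Neuron with fixed input shapes.

    Assumes optimum-neuron is installed on an Inferentia/Trainium host.
    """
    from optimum.neuron.pipelines import pipeline

    shapes = input_shapes if input_shapes is not None else make_input_shapes()
    # The exported model is traced for exactly these static shapes.
    return pipeline("token-classification", model=model_id, export=True, input_shapes=shapes)
```

For example, create_ner_pipeline(make_input_shapes(1, 64)) exports the model for a batch size of 1 and a sequence length of 64. Calling the result on "My name is Philipp and I live in Nuremberg." returns the detected entities.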
The pipeline() function accepts any supported model from the Model Hub. There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task.
You can check the list of supported architectures in the Optimum Neuron documentation.