ONNX configurations

Configuration classes for ONNX exports

Exporting a model to ONNX involves specifying:

  1. The input names.

  2. The output names.

  3. The dynamic axes. These are the input dimensions that can change at runtime (e.g. a batch size or sequence length). All other axes are treated as static, and hence fixed at runtime.

  4. Dummy inputs to trace the model. This is needed in PyTorch to record the computational graph and convert it to ONNX.

Since this data depends on the choice of model and task, we represent it in terms of configuration classes. Each configuration class is associated with a specific model architecture, and follows the naming convention ArchitectureNameOnnxConfig. For instance, the configuration which specifies the ONNX export of BERT models is BertOnnxConfig.
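The four pieces of data listed above can be pictured as one small record. The following is a minimal sketch in plain Python that does not use Optimum: `ExportSpec`, its field names, and the `static_axes` helper are hypothetical names chosen for illustration, and tuples of shapes stand in for real dummy tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ExportSpec:
    """Hypothetical container for the data an ONNX export needs (not an Optimum class)."""

    input_names: List[str]
    output_names: List[str]
    # Tensor name -> {axis index -> symbolic axis name}; unlisted axes are static.
    dynamic_axes: Dict[str, Dict[int, str]]
    # Input name -> shape of the dummy tensor used to trace the model.
    dummy_input_shapes: Dict[str, Tuple[int, ...]]

    def static_axes(self, name: str) -> List[int]:
        """Return the axes of a dummy input that stay fixed at runtime."""
        rank = len(self.dummy_input_shapes[name])
        dynamic = self.dynamic_axes.get(name, {})
        return [axis for axis in range(rank) if axis not in dynamic]


spec = ExportSpec(
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch_size"}, "logits": {0: "batch_size"}},
    dummy_input_shapes={"pixel_values": (2, 3, 64, 64)},
)
print(spec.static_axes("pixel_values"))  # [1, 2, 3]
```

Only the batch axis is declared dynamic here, so the channel, height, and width axes of the traced input would be fixed in the exported graph.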

Since many architectures share similar properties for their ONNX configuration, 🤗 Optimum adopts a 3-level class hierarchy:

  1. Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc).

  2. Middle-end classes. These are aware of the modality, but multiple can exist for the same modality depending on the inputs they support. They specify which input generators should be used for the dummy inputs, but remain model-agnostic.

  3. Model-specific classes like the BertOnnxConfig mentioned above. These are the ones actually used to export models.
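The three levels can be illustrated with toy classes. This is a sketch of the layering only, written in plain Python: the `Toy*` classes and the `DUMMY_INPUT_GENERATORS` attribute are hypothetical stand-ins, not the Optimum classes, and the input names are assumed for illustration.

```python
from abc import ABC, abstractmethod
from typing import Dict, Tuple


class ToyOnnxConfig(ABC):
    """Level 1: modality-agnostic base class (hypothetical stand-in)."""

    # Names of the dummy input generators to use; set by level 2.
    DUMMY_INPUT_GENERATORS: Tuple[str, ...] = ()

    @property
    @abstractmethod
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Input name -> {axis index -> symbolic axis name}."""


class ToyTextEncoderOnnxConfig(ToyOnnxConfig):
    """Level 2: aware of the modality (text), still model-agnostic."""

    DUMMY_INPUT_GENERATORS = ("text",)


class ToyBertOnnxConfig(ToyTextEncoderOnnxConfig):
    """Level 3: model-specific; the only level that can be instantiated here."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        axes = {0: "batch_size", 1: "sequence_length"}
        return {"input_ids": dict(axes), "attention_mask": dict(axes)}


config = ToyBertOnnxConfig()
print(config.DUMMY_INPUT_GENERATORS)  # inherited from the middle level
print(list(config.inputs))            # defined at the model-specific level
```

In this sketch the middle level contributes only the choice of dummy input generator, while the model-specific level contributes the input names and their dynamic axes.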

Base classes

class optimum.exporters.onnx.OnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', preprocessors: typing.Optional[typing.List[typing.Any]] = None, int_dtype: str = 'int64', float_dtype: str = 'fp32', legacy: bool = False )

Parameters

  • config (transformers.PretrainedConfig) — The model configuration.

  • task (str, defaults to "feature-extraction") — The task the model should be exported for.

  • int_dtype (str, defaults to "int64") — The data type of integer tensors. One of "int64", "int32", or "int8".

  • float_dtype (str, defaults to "fp32") — The data type of float tensors. One of "fp32", "fp16", or "bf16".

Base class for ONNX-exportable models. It describes the metadata needed to export the model to the ONNX format.

Class attributes:

  • NORMALIZED_CONFIG_CLASS (Type) — A class derived from NormalizedConfig specifying how to normalize the model config.

  • DUMMY_INPUT_GENERATOR_CLASSES (Tuple[Type]) — A tuple of classes derived from DummyInputGenerator specifying how to create dummy inputs.

  • ATOL_FOR_VALIDATION (Union[float, Dict[str, float]]) — A float or a dictionary mapping task names to float, where the float values represent the absolute tolerance value to use during model conversion validation.

  • DEFAULT_ONNX_OPSET (int, defaults to 11) — The default ONNX opset to use for the ONNX export.

  • MIN_TORCH_VERSION (packaging.version.Version, defaults to optimum.exporters.onnx.utils.TORCH_MINIMUM_VERSION) — The minimum torch version supporting the export of the model to ONNX.

  • MIN_TRANSFORMERS_VERSION (packaging.version.Version, defaults to optimum.exporters.onnx.utils.TRANSFORMERS_MINIMUM_VERSION) — The minimum transformers version supporting the export of the model to ONNX. It is not always up to date or accurate, and is intended mainly for internal use.

  • PATCHING_SPECS (Optional[List[PatchingSpec]], defaults to None) — Specifies which operators / modules should be patched before performing the export, and how. This is useful, for instance, when an operator is not supported in ONNX.
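Since ATOL_FOR_VALIDATION may be either a single float or a per-task dictionary, a consumer has to resolve it before comparing outputs. The following is a minimal sketch of that idea in plain Python. `resolve_atol` and `outputs_match` are hypothetical helpers, not Optimum functions, and treating a task that is missing from the dictionary as an error is an assumption of this sketch.

```python
from typing import Dict, Sequence, Union


def resolve_atol(atol: Union[float, Dict[str, float]], task: str) -> float:
    """Pick the absolute tolerance for a task from a float or a per-task dict."""
    if isinstance(atol, dict):
        # Assumption: a task missing from the dict raises KeyError in this sketch.
        return atol[task]
    return atol


def outputs_match(reference: Sequence[float], exported: Sequence[float], atol: float) -> bool:
    """True when both sequences have equal length and differ by at most atol elementwise."""
    return len(reference) == len(exported) and all(
        abs(r - e) <= atol for r, e in zip(reference, exported)
    )


# Illustrative values only; real tolerances are set per model configuration.
per_task = {"feature-extraction": 1e-4, "text-classification": 1e-3}
reference = [0.25, -1.5]
exported = [0.2504, -1.5002]

loose = resolve_atol(per_task, "text-classification")
strict = resolve_atol(per_task, "feature-extraction")
print(outputs_match(reference, exported, loose))
print(outputs_match(reference, exported, strict))
```

The same pair of outputs passes under the looser tolerance and fails under the stricter one, which is why a per-task dictionary can be useful.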

inputs

( ) → Dict[str, Dict[int, str]]

Returns

Dict[str, Dict[int, str]]

A mapping from each input name to a mapping from axis position to the axis's symbolic name.

Dict containing the axis definition of the input tensors to provide to the model.
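To make the return type concrete, here is a hand-written mapping of the kind described above, with a small helper that applies it to a traced shape. This is a sketch in plain Python: the input names, the axis names, and the `symbolic_shape` helper are assumptions for illustration, not values read from an Optimum configuration.

```python
from typing import Dict, Tuple

# Input name -> {axis position -> symbolic axis name}, written by hand
# for a text encoder. Names are illustrative assumptions.
inputs: Dict[str, Dict[int, str]] = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
}


def symbolic_shape(axes: Dict[int, str], traced_shape: Tuple[int, ...]) -> Tuple[object, ...]:
    """Replace each dynamic axis of a traced shape with its symbolic name."""
    return tuple(axes.get(position, dim) for position, dim in enumerate(traced_shape))


print(symbolic_shape(inputs["input_ids"], (2, 16)))
# ('batch_size', 'sequence_length')
print(symbolic_shape({0: "batch_size"}, (2, 3, 64, 64)))
# ('batch_size', 3, 64, 64)
```

Axes that do not appear in the inner mapping keep the size they had in the dummy input, which matches the static-axis behavior described at the top of this page.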

outputs

( ) → Dict[str, Dict[int, str]]

Returns

Dict[str, Dict[int, str]]

A mapping from each output name to a mapping from axis position to the axis's symbolic name.

Dict containing the axis definition of the output tensors produced by the model.
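Exporters such as `torch.onnx.export` accept input names, output names, and a single `dynamic_axes` dictionary, so the two mappings are typically combined before export. The sketch below shows one way to do that with plain dictionaries; `to_export_arguments` is a hypothetical helper, the tensor names are illustrative, and no exporter is actually called.

```python
from typing import Dict, List, Tuple

AxisMapping = Dict[str, Dict[int, str]]


def to_export_arguments(
    inputs: AxisMapping, outputs: AxisMapping
) -> Tuple[List[str], List[str], AxisMapping]:
    """Derive input names, output names, and one merged dynamic-axes dict."""
    overlap = set(inputs) & set(outputs)
    if overlap:
        raise ValueError(f"names used as both input and output: {sorted(overlap)}")
    # Copy the inner dicts so later edits do not affect the originals.
    dynamic_axes = {name: dict(axes) for name, axes in {**inputs, **outputs}.items()}
    return list(inputs), list(outputs), dynamic_axes


inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
outputs = {"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}

input_names, output_names, dynamic_axes = to_export_arguments(inputs, outputs)
print(input_names, output_names)
print(sorted(dynamic_axes))
```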

generate_dummy_inputs

( framework: str = 'pt', **kwargs ) → Dict

Parameters

  • framework (str, defaults to "pt") — The framework for which to create the dummy inputs.

  • batch_size (int, defaults to 2) — The batch size to use in the dummy inputs.

  • sequence_length (int, defaults to 16) — The sequence length to use in the dummy inputs.

  • num_choices (int, defaults to 4) — The number of candidate answers provided for multiple choice task.

  • image_width (int, defaults to 64) — The width to use in the dummy inputs for vision tasks.

  • image_height (int, defaults to 64) — The height to use in the dummy inputs for vision tasks.

  • num_channels (int, defaults to 3) — The number of channels to use in the dummy inputs for vision tasks.

  • feature_size (int, defaults to 80) — The number of features to use in the dummy inputs for audio tasks when the input is not raw audio. This is, for example, the number of STFT bins or MEL bins.

  • nb_max_frames (int, defaults to 3000) — The number of frames to use in the dummy inputs for audio tasks when the input is not raw audio.

  • audio_sequence_length (int, defaults to 16000) — The number of frames to use in the dummy inputs for audio tasks when the input is raw audio.

Returns

Dict

A dictionary mapping the input names to dummy tensors in the proper framework format.

Generates the dummy inputs necessary for tracing the model. If not explicitly specified, default input shapes are used.
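The idea of shape-driven dummy inputs can be sketched without any tensor library. The function below is a hypothetical illustration, not the Optimum implementation: it reuses the documented defaults for `batch_size` and `sequence_length`, while `vocab_size`, `seed`, and the nested-list representation are assumptions made so that the example runs with the standard library alone.

```python
import random
from typing import Dict, List


def toy_dummy_text_inputs(
    batch_size: int = 2,
    sequence_length: int = 16,
    vocab_size: int = 100,  # assumption: not a documented parameter
    seed: int = 0,          # assumption: added for reproducibility
) -> Dict[str, List[List[int]]]:
    """Build nested lists shaped (batch_size, sequence_length) for each text input."""
    rng = random.Random(seed)
    input_ids = [
        [rng.randrange(vocab_size) for _ in range(sequence_length)]
        for _ in range(batch_size)
    ]
    attention_mask = [[1] * sequence_length for _ in range(batch_size)]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


default_inputs = toy_dummy_text_inputs()
custom_inputs = toy_dummy_text_inputs(batch_size=4, sequence_length=8)

for name, rows in custom_inputs.items():
    print(name, (len(rows), len(rows[0])))
```

Only the shapes matter for tracing in this sketch; the token values are arbitrary as long as they fall inside the assumed vocabulary range.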

class optimum.exporters.onnx.OnnxConfigWithPast

( config: PretrainedConfig, task: str = 'feature-extraction', int_dtype: str = 'int64', float_dtype: str = 'fp32', use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: typing.Optional[typing.List[typing.Any]] = None, legacy: bool = False )

Inherits from OnnxConfig. A base class to handle the ONNX configuration of decoder-only models.

add_past_key_values

( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]], direction: str )

Parameters

  • inputs_or_outputs (Dict[str, Dict[int, str]]) — The mapping to fill.

  • direction (str) — Either "inputs" or "outputs". It specifies whether inputs_or_outputs is the input mapping or the output mapping, which matters for axis naming.

Fills the inputs_or_outputs mapping with the past_key_values dynamic axes, taking the direction into account.
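The effect of the direction argument can be sketched as follows. This is a hypothetical stand-in, not the Optimum method: the tensor name prefixes, the axis names, the `num_layers` parameter, and the assumed cache layout (batch, heads, sequence, head dimension) are all assumptions chosen to show why input and output axes are named differently.

```python
from typing import Dict


def toy_add_past_key_values(
    inputs_or_outputs: Dict[str, Dict[int, str]], direction: str, num_layers: int
) -> None:
    """Add one key and one value entry per layer to the mapping, in place."""
    if direction not in ("inputs", "outputs"):
        raise ValueError(f'direction must be "inputs" or "outputs", got {direction!r}')

    if direction == "inputs":
        # Cache fed to the model: it covers the tokens seen so far.
        prefix, length_axis = "past_key_values", "past_sequence_length"
    else:
        # Cache returned by the model: it has grown by the new tokens.
        prefix, length_axis = "present", "past_sequence_length + sequence_length"

    for layer in range(num_layers):
        for kind in ("key", "value"):
            # Assumed layout: axis 0 is the batch, axis 2 is the sequence.
            inputs_or_outputs[f"{prefix}.{layer}.{kind}"] = {0: "batch_size", 2: length_axis}


input_axes = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
output_axes = {"logits": {0: "batch_size", 1: "sequence_length"}}
toy_add_past_key_values(input_axes, "inputs", num_layers=2)
toy_add_past_key_values(output_axes, "outputs", num_layers=2)

print(list(input_axes))
print(output_axes["present.0.key"])
```

The mapping is modified in place, which mirrors the "fills" wording above: the caller passes in the mapping it already built and receives the extra cache entries in the same dictionary.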

class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast

( config: PretrainedConfig, task: str = 'feature-extraction', int_dtype: str = 'int64', float_dtype: str = 'fp32', use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'>, preprocessors: typing.Optional[typing.List[typing.Any]] = None, legacy: bool = False )

Inherits from OnnxConfigWithPast. A base class to handle the ONNX configuration of encoder-decoder models.

with_behavior

( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior], use_past: bool = False, use_past_in_inputs: bool = False ) → OnnxSeq2SeqConfigWithPast

Parameters

  • behavior (ConfigBehavior) — The behavior to use for the new instance.

  • use_past (bool, defaults to False) — Whether or not the ONNX config to instantiate is for a model using KV cache.

  • use_past_in_inputs (bool, defaults to False) — Whether the KV cache is to be passed as an input to the ONNX model.

Returns

OnnxSeq2SeqConfigWithPast

Creates a copy of the current OnnxConfig but with a different ConfigBehavior and use_past value.
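The copy-with-a-different-behavior pattern can be sketched with a frozen dataclass. Everything below is a hypothetical illustration rather than the Optimum classes: `ToyBehavior` mirrors the documented 'monolith' default, while its "encoder" and "decoder" members and the `ToySeq2SeqConfig` fields are assumptions of this sketch.

```python
from dataclasses import dataclass, replace
from enum import Enum
from typing import Union


class ToyBehavior(str, Enum):
    MONOLITH = "monolith"
    ENCODER = "encoder"  # assumption: member name chosen for illustration
    DECODER = "decoder"  # assumption: member name chosen for illustration


@dataclass(frozen=True)
class ToySeq2SeqConfig:
    task: str = "feature-extraction"
    behavior: ToyBehavior = ToyBehavior.MONOLITH
    use_past: bool = False
    use_past_in_inputs: bool = False

    def with_behavior(
        self,
        behavior: Union[str, ToyBehavior],
        use_past: bool = False,
        use_past_in_inputs: bool = False,
    ) -> "ToySeq2SeqConfig":
        """Return a new config; the original instance is left unchanged."""
        return replace(
            self,
            behavior=ToyBehavior(behavior),  # accepts a string or an enum member
            use_past=use_past,
            use_past_in_inputs=use_past_in_inputs,
        )


monolith = ToySeq2SeqConfig(task="text2text-generation")
decoder = monolith.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

print(monolith.behavior.value, monolith.use_past)
print(decoder.behavior.value, decoder.use_past, decoder.task)
```

Returning a copy instead of mutating the instance means one base configuration can be used to derive separate encoder and decoder configurations without either affecting the other.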

Middle-end classes

Text

class optimum.exporters.onnx.TextEncoderOnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', preprocessors: typing.Optional[typing.List[typing.Any]] = None, int_dtype: str = 'int64', float_dtype: str = 'fp32', legacy: bool = False )

Handles encoder-based text architectures.

class optimum.exporters.onnx.TextDecoderOnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', int_dtype: str = 'int64', float_dtype: str = 'fp32', use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: typing.Optional[typing.List[typing.Any]] = None, legacy: bool = False )

Handles decoder-based text architectures.

class optimum.exporters.onnx.TextSeq2SeqOnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', int_dtype: str = 'int64', float_dtype: str = 'fp32', use_past: bool = False, use_past_in_inputs: bool = False, behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'>, preprocessors: typing.Optional[typing.List[typing.Any]] = None, legacy: bool = False )

Handles encoder-decoder-based text architectures.

Vision

class optimum.exporters.onnx.config.VisionOnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', preprocessors: typing.Optional[typing.List[typing.Any]] = None, int_dtype: str = 'int64', float_dtype: str = 'fp32', legacy: bool = False )

Handles vision architectures.

Multi-modal

class optimum.exporters.onnx.config.TextAndVisionOnnxConfig

( config: PretrainedConfig, task: str = 'feature-extraction', preprocessors: typing.Optional[typing.List[typing.Any]] = None, int_dtype: str = 'int64', float_dtype: str = 'fp32', legacy: bool = False )

Handles multi-modal text and vision architectures.
