FlaxAutoModel
class transformers.FlaxAutoModel
( *args, **kwargs )
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertModel (ALBERT model) 
- BartConfig configuration class: FlaxBartModel (BART model) 
- BeitConfig configuration class: FlaxBeitModel (BEiT model) 
- BertConfig configuration class: FlaxBertModel (BERT model) 
- BigBirdConfig configuration class: FlaxBigBirdModel (BigBird model) 
- BlenderbotConfig configuration class: FlaxBlenderbotModel (Blenderbot model) 
- BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallModel (BlenderbotSmall model) 
- BloomConfig configuration class: FlaxBloomModel (BLOOM model) 
- CLIPConfig configuration class: FlaxCLIPModel (CLIP model) 
- DistilBertConfig configuration class: FlaxDistilBertModel (DistilBERT model) 
- ElectraConfig configuration class: FlaxElectraModel (ELECTRA model) 
- GPT2Config configuration class: FlaxGPT2Model (OpenAI GPT-2 model) 
- GPTJConfig configuration class: FlaxGPTJModel (GPT-J model) 
- GPTNeoConfig configuration class: FlaxGPTNeoModel (GPT Neo model) 
- LongT5Config configuration class: FlaxLongT5Model (LongT5 model) 
- MBartConfig configuration class: FlaxMBartModel (mBART model) 
- MT5Config configuration class: FlaxMT5Model (MT5 model) 
- MarianConfig configuration class: FlaxMarianModel (Marian model) 
- OPTConfig configuration class: FlaxOPTModel (OPT model) 
- PegasusConfig configuration class: FlaxPegasusModel (Pegasus model) 
- RegNetConfig configuration class: FlaxRegNetModel (RegNet model) 
- ResNetConfig configuration class: FlaxResNetModel (ResNet model) 
- RoFormerConfig configuration class: FlaxRoFormerModel (RoFormer model) 
- RobertaConfig configuration class: FlaxRobertaModel (RoBERTa model) 
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model) 
- T5Config configuration class: FlaxT5Model (T5 model) 
- ViTConfig configuration class: FlaxViTModel (ViT model) 
- VisionTextDualEncoderConfig configuration class: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model) 
- Wav2Vec2Config configuration class: FlaxWav2Vec2Model (Wav2Vec2 model) 
- WhisperConfig configuration class: FlaxWhisperModel (Whisper model) 
- XGLMConfig configuration class: FlaxXGLMModel (XGLM model) 
- XLMRobertaConfig configuration class: FlaxXLMRobertaModel (XLM-RoBERTa model) 
 
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
Examples:
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-cased")
>>> model = FlaxAutoModel.from_config(config)
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch checkpoint to a Flax model using the provided conversion scripts and loading the Flax model afterwards.
 
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model's __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
 
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (e.g., not try downloading the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
 
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert — FlaxAlbertModel (ALBERT model)
- bart — FlaxBartModel (BART model)
- beit — FlaxBeitModel (BEiT model)
- bert — FlaxBertModel (BERT model)
- big_bird — FlaxBigBirdModel (BigBird model)
- blenderbot — FlaxBlenderbotModel (Blenderbot model)
- blenderbot-small — FlaxBlenderbotSmallModel (BlenderbotSmall model)
- bloom — FlaxBloomModel (BLOOM model)
- clip — FlaxCLIPModel (CLIP model)
- distilbert — FlaxDistilBertModel (DistilBERT model)
- electra — FlaxElectraModel (ELECTRA model)
- gpt-sw3 — FlaxGPT2Model (GPT-Sw3 model)
- gpt2 — FlaxGPT2Model (OpenAI GPT-2 model)
- gpt_neo — FlaxGPTNeoModel (GPT Neo model)
- gptj — FlaxGPTJModel (GPT-J model)
- longt5 — FlaxLongT5Model (LongT5 model)
- marian — FlaxMarianModel (Marian model)
- mbart — FlaxMBartModel (mBART model)
- mt5 — FlaxMT5Model (MT5 model)
- opt — FlaxOPTModel (OPT model)
- pegasus — FlaxPegasusModel (Pegasus model)
- regnet — FlaxRegNetModel (RegNet model)
- resnet — FlaxResNetModel (ResNet model)
- roberta — FlaxRobertaModel (RoBERTa model)
- roberta-prelayernorm — FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer — FlaxRoFormerModel (RoFormer model)
- t5 — FlaxT5Model (T5 model)
- vision-text-dual-encoder — FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit — FlaxViTModel (ViT model)
- wav2vec2 — FlaxWav2Vec2Model (Wav2Vec2 model)
- whisper — FlaxWhisperModel (Whisper model)
- xglm — FlaxXGLMModel (XGLM model)
- xlm-roberta — FlaxXLMRobertaModel (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
...     "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
