Adding BetterTransformer support for new architectures
You want to add a new model for Better Transformer, the fast path of the PyTorch Transformer API? Check this guideline!
Models that should be supported
In theory, any model that has a transformer encoder layer, similar to the classic encoder described in the "Attention Is All You Need" paper, should be supported. More specifically, a model that has an encoder block with a Multi-Head Attention module (with pre- or post-attention layer norm) should be convertible to its BetterTransformer equivalent. The conditions can be summarized as follows:
Use the classic Multi-Head Attention module (for example, DeBERTa cannot be supported)
Use either the gelu or relu activation function
Have an even number of attention heads
Do not use any attention bias (for example, T5 uses attention bias, therefore it cannot be supported)
eps must be equal between the first and second layer norms for each layer
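For instance, some of these conditions can be checked quickly from the model configuration. This is only a minimal sketch: the checkpoint name is an arbitrary example, and the attribute names (hidden_act, num_attention_heads) assume a BERT-like configuration.

```python
from transformers import AutoConfig

# Illustrative check of a few support conditions (BERT-like config assumed)
config = AutoConfig.from_pretrained("bert-base-uncased")

assert config.hidden_act in ("gelu", "relu"), "activation must be gelu or relu"
assert config.num_attention_heads % 2 == 0, "number of attention heads must be even"
```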
How to convert a model into its BetterTransformer format?
Step 1: Identifying the source layer to change
First, go to optimum/bettertransformer/__init__.py and you'll see the dictionary BetterTransformerManager.MODEL_MAPPING. This should contain the mapping between a model type and the Tuple[str, BetterTransformerBaseLayer] composed of the name of the nn.Module that can be converted to its BetterTransformer equivalent, and effectively the equivalent BetterTransformer layer class.
Let us try to do it step by step for Bert. First, we need to identify the layers that need to be replaced:
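A simple way to do this is to load the model and print its module structure. This is only a sketch; bert-base-uncased is used as an arbitrary example checkpoint and the printed output is abbreviated.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
print(model)
# BertModel(
#   (embeddings): BertEmbeddings(...)
#   (encoder): BertEncoder(
#     (layer): ModuleList(
#       (0): BertLayer(
#         (attention): BertAttention(...)
#         (intermediate): BertIntermediate(...)
#         (output): BertOutput(...)
#       )
#       ...
```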
You can clearly see that the layers that need to be replaced are the BertLayer modules, since they contain the whole encoder layer module.
Step 2: Building the xxxLayerBetterTransformer module
Check that the identified module is not already copied from another module (by inspecting the source code in transformers and checking that the class definition does not start with # Copied from ...). If not, create a class in bettertransformer/models/encoder_model.py. Start with these lines:
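The skeleton typically looks like the following. This is a hedged sketch: the import path and constructor signature are assumptions modeled on the existing BetterTransformer layer classes, so double-check them against the classes already in bettertransformer/models.

```python
import torch
import torch.nn as nn

# Assumed import path; check how the existing layer classes import the base class
from .base import BetterTransformerBaseLayer


class BertLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, bert_layer, config):
        super().__init__(config)
        # Fill in the attributes listed below from the submodules of `bert_layer`
```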
Now, make sure to fill in all the necessary attributes; the list of attributes is:
in_proj_weight
in_proj_bias
out_proj_weight
out_proj_bias
linear1_weight
linear1_bias
linear2_weight
linear2_bias
norm1_eps
norm1_weight
norm1_bias
norm2_weight
norm2_bias
num_heads
embed_dim
Note that these attributes correspond to all the components that are necessary to run a Transformer encoder module; check Figure 1 of the "Attention Is All You Need" paper.
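As an illustration, here is a sketch of how these attributes could be filled from a BertLayer inside the constructor above. The submodule paths follow the BertLayer layout in transformers, but treat this as a starting point rather than the definitive implementation, and compare with the existing classes in the repository.

```python
# Inside BertLayerBetterTransformer.__init__ (illustrative sketch)
self.in_proj_weight = nn.Parameter(
    torch.cat(
        [
            bert_layer.attention.self.query.weight,
            bert_layer.attention.self.key.weight,
            bert_layer.attention.self.value.weight,
        ]
    )
)
self.in_proj_bias = nn.Parameter(
    torch.cat(
        [
            bert_layer.attention.self.query.bias,
            bert_layer.attention.self.key.bias,
            bert_layer.attention.self.value.bias,
        ]
    )
)
self.out_proj_weight = bert_layer.attention.output.dense.weight
self.out_proj_bias = bert_layer.attention.output.dense.bias
self.linear1_weight = bert_layer.intermediate.dense.weight
self.linear1_bias = bert_layer.intermediate.dense.bias
self.linear2_weight = bert_layer.output.dense.weight
self.linear2_bias = bert_layer.output.dense.bias
self.norm1_eps = bert_layer.attention.output.LayerNorm.eps
self.norm1_weight = bert_layer.attention.output.LayerNorm.weight
self.norm1_bias = bert_layer.attention.output.LayerNorm.bias
self.norm2_weight = bert_layer.output.LayerNorm.weight
self.norm2_bias = bert_layer.output.LayerNorm.bias
self.num_heads = config.num_attention_heads
self.embed_dim = config.hidden_size
```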
Once you have filled in all these attributes (sometimes the query, key and value layers need to be "contiguified"; check the modeling_encoder.py file to understand more), make sure also to add the lines:
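As a sketch based on the existing layer classes, these are an is_last_layer flag and a final validation call (the method name is assumed from the existing implementations):

```python
self.is_last_layer = False

self.validate_bettertransformer()
```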
Step 3: Building the forward pass
First of all, start with the line super().forward_checker(); this is needed so that the parent class can run all the safety checks first.
After the first forward pass, the hidden states need to be nested using the attention mask. Once they are nested, the attention mask is not needed anymore and can therefore be set to None. This is how the forward pass is built for Bert; these lines should remain pretty much similar across models, but sometimes the shapes of the attention masks differ from one model to another.
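As a sketch, adapted from what the existing Bert layer does (the exact tensor shapes and mask conventions may vary from model to model):

```python
super().forward_checker()

if hidden_states.is_nested:
    attention_mask = None

if attention_mask is not None:
    # The attention mask comes in with 0 for tokens to keep and -inf for masked
    # tokens; convert it to a boolean mask where True marks a padded position.
    attention_mask = attention_mask.bool()
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
    seqlen = attention_mask.shape[1]
    lengths = torch.sum(~attention_mask, 1)
    # Only nest if at least one sequence is shorter than the padded length
    if not all(l == seqlen for l in lengths):
        hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
    attention_mask = None
```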
Once the hidden_states are nested, call torch._transformer_encoder_layer_fwd using the right arguments as follows:
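A sketch of the call, with the argument order taken from the existing encoder layers; self.use_gelu and self.norm_first are assumed to be set by the base class. Verify both against the classes already in bettertransformer/models before relying on this.

```python
hidden_states = torch._transformer_encoder_layer_fwd(
    hidden_states,
    self.embed_dim,
    self.num_heads,
    self.in_proj_weight,
    self.in_proj_bias,
    self.out_proj_weight,
    self.out_proj_bias,
    self.use_gelu,
    self.norm_first,
    self.norm1_eps,
    self.norm1_weight,
    self.norm1_bias,
    self.norm2_weight,
    self.norm2_bias,
    self.linear1_weight,
    self.linear1_bias,
    self.linear2_weight,
    self.linear2_bias,
    attention_mask,
)
```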
At the last layer, it is important to "un-nest" the hidden_states so that they can be processed by the following modules; this is done in these lines:
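A sketch, mirroring the existing Bert implementation:

```python
if hidden_states.is_nested and self.is_last_layer:
    hidden_states = hidden_states.to_padded_tensor(0.0)

return (hidden_states,)
```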
Also make sure to return a tuple to follow the convention of transformers.
The best way to reproduce this on your own model is to try it out, taking some inspiration from the provided modeling scripts. Of course, we will be happy to help you convert your model if you open an issue or a Pull Request on optimum!
Step 4: Sanity check!
As a last step, make sure to update the BetterTransformerManager.MODEL_MAPPING dictionary in optimum/bettertransformer/__init__.py with the correct names, and you should be ready to convert your model. For example, for Bert that would be:
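A sketch of the entry to add (the surrounding entries are omitted; the exact dictionary shape should match what is already in __init__.py):

```python
MODEL_MAPPING = {
    # ... existing entries ...
    "bert": ("BertLayer", BertLayerBetterTransformer),
}
```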
Try it out with the conversion method that is presented in the tutorials section!