XLA Integration for TensorFlow Models
Accelerated Linear Algebra (XLA) is a compiler that speeds up the runtime of TensorFlow models. From the official documentation:
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.
Using XLA in TensorFlow is simple: it comes packaged inside the tensorflow library, and it can be triggered with the jit_compile argument in any graph-creating function such as tf.function. When using Keras methods like fit() and predict(), you can enable XLA simply by passing the jit_compile argument to model.compile(). However, XLA is not limited to these methods; it can also be used to accelerate any arbitrary tf.function.
Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as GPT2, T5 and OPT, as well as speech processing for models such as Whisper.
While the exact speed-up is very much model-dependent, for TensorFlow text generation models inside 🤗 Transformers, we noticed a speed-up of ~100x. This document explains how you can use XLA with these models to get the maximum performance. We'll also provide links to additional resources if you're interested in learning more about the benchmarks and our design philosophy behind the XLA integration.
Running TF functions with XLA
Let us consider the following model in TensorFlow:
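A minimal sketch of such a model, assuming a small tf.keras.Sequential stack. The layer sizes and activations are illustrative assumptions; only the per-example input shape of (10,) comes from the text.

```python
import tensorflow as tf

# Two dense layers. The Input entry fixes the per-example input shape to (10,).
# The layer widths (10 and 5) and the activations are assumptions for illustration.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(5, activation="softmax"),
    ]
)
```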
The above model accepts inputs with a dimension of (10,). We can use the model to run a forward pass like so:
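One way to run an eager forward pass, sketched on an assumed two-layer toy model so the snippet runs on its own. The batch size of 3 and the random inputs are arbitrary choices for illustration.

```python
import tensorflow as tf

# Assumed toy model that takes inputs of shape (10,) per example.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(5, activation="softmax"),
    ]
)

# A batch of 3 random examples, each with 10 features.
batch_size, input_dim = 3, 10
random_inputs = tf.random.normal((batch_size, input_dim))

# Plain (eager) forward pass: calling the model runs its call() method.
outputs = model(random_inputs)
print(outputs.shape)  # (3, 5)
```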
To run the forward pass with an XLA-compiled function, we'd need to do:
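A sketch of the XLA-compiled forward pass: wrapping the model in tf.function with jit_compile=True asks TensorFlow to compile the model's call() with XLA. The toy model and the names xla_fn and random_inputs are assumptions for illustration.

```python
import tensorflow as tf

# Assumed toy model that takes inputs of shape (10,) per example.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(5, activation="softmax"),
    ]
)
random_inputs = tf.random.normal((3, 10))

# Eager reference result, kept only to compare against the compiled version.
eager_outputs = model(random_inputs)

# Wrap the model so that its forward pass is traced and compiled with XLA.
xla_fn = tf.function(model, jit_compile=True)
xla_outputs = xla_fn(random_inputs)

# Both paths should agree up to small floating-point differences.
max_diff = float(tf.reduce_max(tf.abs(eager_outputs - xla_outputs)))
print(f"max abs difference: {max_diff:.2e}")
```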
The default call() function of the model is used for compiling the XLA graph. But if there's any other model function you want to compile into XLA, that's also possible with:
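As a sketch, any other bound method can be passed to tf.function in the same way. The TinyModel class and its compute_logits method below are hypothetical names introduced only for this illustration.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical model with one extra method besides call()."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(5)

    def call(self, inputs):
        # Default forward pass: probabilities.
        return tf.nn.softmax(self.dense(inputs))

    def compute_logits(self, inputs):
        # Hypothetical extra method: raw scores without the softmax.
        return self.dense(inputs)


model = TinyModel()
_ = model(tf.zeros((1, 10)))  # build the weights eagerly before compiling

# Compile the extra method, rather than call(), with XLA.
xla_logits_fn = tf.function(model.compute_logits, jit_compile=True)
logits = xla_logits_fn(tf.random.normal((3, 10)))
print(logits.shape)  # (3, 5)
```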
Running a TF text generation model with XLA from 🤗 Transformers
To enable XLA-accelerated generation within 🤗 Transformers, you need to have a recent version of transformers installed. You can install it by running:
pip install transformers --upgrade
And then you can run the following code:
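A hedged sketch of XLA-accelerated generation, assuming an installed transformers version that still ships the TensorFlow classes (such as TFAutoModelForCausalLM) and the publicly available gpt2 checkpoint. The helper name generate_with_xla, the choice of the EOS token as padding token, and the max_new_tokens value are assumptions for illustration.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM


def generate_with_xla(prompts, checkpoint="gpt2", max_new_tokens=16):
    """Generate continuations for `prompts` with an XLA-compiled generate()."""
    # Left padding keeps the prompt adjacent to the generated tokens.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
    # GPT-2 has no padding token by default; reusing EOS is an assumption here.
    tokenizer.pad_token = tokenizer.eos_token
    model = TFAutoModelForCausalLM.from_pretrained(checkpoint)

    # The single XLA-specific line: wrap generate() in a compiled tf.function.
    xla_generate = tf.function(model.generate, jit_compile=True)

    tokenized = tokenizer(prompts, padding=True, return_tensors="tf")
    generated = xla_generate(**tokenized, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
```

Calling generate_with_xla(["TensorFlow is"]) downloads the checkpoint on first use, so it needs network access. The first call is also slow because of XLA compilation.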
As you can see, enabling XLA on generate() takes just a single line of code. The rest of the code remains unchanged. However, there are a couple of gotchas in the above code snippet that are specific to XLA. You need to be aware of them to realize the speed-ups that XLA can bring. We discuss these in the following section.
Gotchas to be aware of
When you execute an XLA-enabled function (like xla_generate() above) for the first time, it will internally try to infer the computation graph, which is time-consuming. This process is known as "tracing".
As a result, the first generation call will not be fast. Successive calls of xla_generate() (or any other XLA-enabled function) won't have to infer the computation graph again, provided the inputs have the same shape as the one the computation graph was initially built with. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with variable input shape modalities (e.g., text).
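The effect of input shapes on tracing can be sketched with a toy function. The function double and the list trace_log are hypothetical stand-ins: Python code inside a tf.function runs only while the function is being traced, so the list records one entry per trace.

```python
import tensorflow as tf

trace_log = []  # one entry per trace, filled by Python code that runs only while tracing


@tf.function(jit_compile=True)
def double(x):
    # This append is ordinary Python, so it executes during tracing only,
    # not on calls that reuse an already traced graph.
    trace_log.append(tuple(x.shape.as_list()))
    return x * 2


double(tf.zeros((1, 8)))   # first call: traces and compiles for shape (1, 8)
double(tf.ones((1, 8)))    # same shape and dtype: reuses the existing graph
double(tf.zeros((1, 16)))  # new shape: triggers a second trace

print(trace_log)  # [(1, 8), (1, 16)]
```

Three calls lead to only two traces here, because the second call matches the shape the function was already traced with.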
To ensure xla_generate() always operates with the same input shapes, you can specify the padding arguments when calling the tokenizer.
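To see why padding helps, here is a plain-Python sketch of padding token sequences up to a multiple of a fixed length. The helper pad_to_multiple, the multiple of 8, and the pad ID of 0 are assumptions for illustration; with a 🤗 tokenizer the same effect comes from its padding and pad_to_multiple_of arguments.

```python
def pad_to_multiple(token_ids, multiple=8, pad_id=0):
    """Left-pad a list of token IDs so that its length is a multiple of `multiple`."""
    remainder = len(token_ids) % multiple
    if remainder == 0:
        return list(token_ids)
    return [pad_id] * (multiple - remainder) + list(token_ids)


# With a Hugging Face tokenizer, the analogous call would look like:
#   tokenizer(texts, padding=True, pad_to_multiple_of=8, return_tensors="tf")
for length in (3, 5, 8, 11):
    padded = pad_to_multiple(list(range(1, length + 1)))
    print(length, "->", len(padded))
# 3 -> 8
# 5 -> 8
# 8 -> 8
# 11 -> 16
```

Inputs of length 3, 5 and 8 all end up with length 8, so they share one traced shape. Only inputs longer than 8 tokens move to the next bucket.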
This way, you can ensure that xla_generate() always receives inputs with the shape it was traced with, which leads to speed-ups in the generation time. You can verify this with the code below:
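A sketch of such a timing check, using a small XLA-compiled toy function as a stand-in for xla_generate() so that it runs without downloading a model. The function xla_step and the input shape (1, 8) are assumptions, and the measured times depend entirely on the hardware.

```python
import time

import tensorflow as tf


@tf.function(jit_compile=True)
def xla_step(x):
    # Stand-in for an expensive XLA-compiled function such as xla_generate().
    return tf.reduce_sum(tf.matmul(x, x, transpose_b=True))


timings_ms = []
for _ in range(3):
    batch = tf.random.normal((1, 8))  # same shape on every call, so no re-tracing
    start = time.perf_counter()
    xla_step(batch).numpy()  # .numpy() waits until the result is actually computed
    timings_ms.append((time.perf_counter() - start) * 1e3)
    print(f"Execution time -- {timings_ms[-1]:.1f} ms")
```

The first iteration includes tracing and compilation, so it is typically the slowest of the three.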
On a Tesla T4 GPU, you can expect the outputs like so:
The first call to xla_generate() is time-consuming because of tracing, but the successive calls are orders of magnitude faster. Keep in mind that any change in the generation options at any point will trigger re-tracing and thus lead to slow-downs in the generation time.
We didn't cover all the text generation options 🤗 Transformers provides in this document. We encourage you to read the documentation for advanced use cases.
Additional Resources
Here, we leave you with some additional resources if you want to delve deeper into XLA in 🤗 Transformers and in general.
This Colab Notebook provides an interactive demonstration if you want to fiddle with the XLA-compatible encoder-decoder (like T5) and decoder-only (like GPT2) text generation models.
This blog post provides an overview of the comparison benchmarks for XLA-compatible models along with a friendly introduction to XLA in TensorFlow.
This blog post discusses our design philosophy behind adding XLA support to the TensorFlow models in 🤗 Transformers.
Recommended posts for learning more about XLA and TensorFlow graphs in general: