Inference on many GPUs
This document contains information on how to run inference efficiently on multiple GPUs.
Note: A multi-GPU setup can use the majority of the strategies described in the single-GPU section. Be aware, though, of a few simple techniques that can improve utilization further.
Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the single-GPU documentation.
BetterTransformer converts 🌍 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
BetterTransformer is also supported for faster inference on single and multi-GPU for text, image, and audio models.
Flash Attention can only be used for models using fp16 or bf16 dtype. Make sure to cast your model to the appropriate dtype before using BetterTransformer.
For text models, especially decoder-based models (GPT, T5, Llama, etc.), the BetterTransformer API converts all attention operations to use the torch.nn.functional.scaled_dot_product_attention operator (SDPA), which is only available in PyTorch 2.0 and onwards.
To convert a model to BetterTransformer:
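A minimal sketch, assuming an illustrative checkpoint; `to_bettertransformer()` requires the Optimum library (`pip install optimum`), and `device_map="auto"` is one way to spread the weights across the available GPUs:

```python
import torch
from transformers import AutoModelForCausalLM

# device_map="auto" (via accelerate) shards the weights across available GPUs;
# fp16 is needed so the Flash Attention kernels can be used
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", torch_dtype=torch.float16, device_map="auto"
)

# convert the model to use the BetterTransformer fastpath
model = model.to_bettertransformer()
```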
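The converted model can then be used for generation as usual (continuing the sketch above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
inputs = tokenizer("Hello my dog is cute and", return_tensors="pt").to(model.device)

outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```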
If you see a bug with a traceback like the following (the exact wording may vary across PyTorch versions):
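```
RuntimeError: No available kernel. Aborting execution.
```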
try using the PyTorch nightly version, which may have a broader coverage for Flash Attention:
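For example (this index URL assumes a CUDA 11.8 build; pick the one matching your setup):

```bash
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```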
Because the torch.nn.TransformerEncoderLayer fastpath does not support training, it is dispatched to torch.nn.functional.scaled_dot_product_attention instead, which does not leverage nested tensors but can use Flash Attention or Memory-Efficient Attention fused kernels.
You can combine the different methods described above to get the best performance for your model. For example, you can use BetterTransformer with FP4 mixed-precision inference + flash attention:
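A sketch under these assumptions: the checkpoint name is illustrative, and 4-bit loading requires the bitsandbytes library:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# FP4 quantization with fp16 compute, so the Flash Attention kernels can be used
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=quantization_config, device_map="auto"
)

# convert to BetterTransformer so attention dispatches to SDPA
model = model.to_bettertransformer()

inputs = tokenizer("Hello my dog is cute and", return_tensors="pt").to("cuda")

# restrict SDPA to the Flash Attention kernel
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```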
SDPA can also call Flash Attention kernels under the hood. To enable Flash Attention or to check that it is available in a given setting (hardware, problem size), use torch.backends.cuda.sdp_kernel as a context manager:
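A minimal sketch, reusing the `model` and `inputs` from the example above (note that newer PyTorch releases supersede this context manager with `torch.nn.attention.sdpa_kernel`):

```python
import torch

# raise an error instead of silently falling back to another kernel
# if Flash Attention cannot be used in this setting
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)
```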
Have a look at this blog post to learn more about what is possible with the BetterTransformer + SDPA API.
For encoder models during inference, BetterTransformer dispatches the forward call of encoder layers to an equivalent of torch.nn.TransformerEncoderLayer that will execute the fastpath implementation of the encoder layers.
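For example, a sketch converting an encoder model (the checkpoint name is illustrative):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
# encoder layers are swapped for their fastpath equivalent
model = model.to_bettertransformer()
```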
More details about BetterTransformer performance can be found in this blog post, and you can learn more about BetterTransformer for encoder models in this blog post.