Efficient Inference on Multiple GPUs
This document contains information on how to run inference efficiently on multiple GPUs.
Note: A multi-GPU setup can use the majority of the strategies described in the single-GPU section. There are, however, a few simple techniques you should be aware of to make better use of your hardware.
Flash Attention 2
Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section of the single-GPU documentation.
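As a brief sketch (the checkpoint is arbitrary, and the exact argument name depends on your transformers version), Flash Attention 2 can be combined with device_map="auto" so the model is sharded across all visible GPUs:

```python
import torch
from transformers import AutoModelForCausalLM

# Load in bf16, enable Flash Attention 2, and shard the model across available GPUs
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # older transformers versions use use_flash_attention_2=True
    device_map="auto",
)
```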
BetterTransformer
BetterTransformer converts 🌍 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
BetterTransformer is also supported for faster inference on single and multi-GPU for text, image, and audio models.
Flash Attention can only be used for models using fp16 or bf16 dtype. Make sure to cast your model to the appropriate dtype before using BetterTransformer.
Decoder models
For text models, especially decoder-based models (GPT, T5, Llama, etc.), the BetterTransformer API converts all attention operations to use the torch.nn.functional.scaled_dot_product_attention
operator (SDPA) that is only available in PyTorch 2.0 and onwards.
To convert a model to BetterTransformer:
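A minimal sketch, assuming a causal LM checkpoint (facebook/opt-350m is only an example) and the optimum package installed:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the model in fp16 so the Flash Attention / SDPA fastpath can be used
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")

# Convert the model to use the BetterTransformer fastpath (requires the optimum package)
model = model.to_bettertransformer()
```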
SDPA can also call Flash Attention kernels under the hood. To enable Flash Attention or to check that it is available in a given setting (hardware, problem size), use torch.backends.cuda.sdp_kernel
as a context manager:
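A sketch building on the model above (checkpoint and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
model = model.to_bettertransformer()

inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt").to("cuda")

# Restrict SDPA to the Flash Attention kernel (disable the math and memory-efficient fallbacks)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```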
If you see a bug with a traceback saying something like the following (the exact wording can vary with your PyTorch version)
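```
RuntimeError: No available kernel. Aborting execution.
```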
try using the PyTorch nightly version, which may have broader coverage for Flash Attention:
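One way to install it, assuming a CUDA 11.8 setup (adjust the index URL to match your CUDA version):

```bash
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```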
Have a look at this blog post to learn more about what is possible with the BetterTransformer + SDPA API.
Encoder models
For encoder models during inference, BetterTransformer dispatches the forward call of encoder layers to an equivalent of torch.nn.TransformerEncoderLayer
that will execute the fastpath implementation of the encoder layers.
Because torch.nn.TransformerEncoderLayer
fastpath does not support training, it is dispatched to torch.nn.functional.scaled_dot_product_attention
instead, which does not leverage nested tensors but can use Flash Attention or Memory-Efficient Attention fused kernels.
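A short sketch with an encoder model (the checkpoint is only an example):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to("cuda")

# Convert the encoder model to BetterTransformer; the fastpath is used at inference time
model = model.to_bettertransformer()
model.eval()

inputs = tokenizer("BetterTransformer speeds up inference.", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits)
```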
More details about BetterTransformer performance can be found in this blog post, and you can learn more about BetterTransformer for encoder models in this blog.
Advanced usage: mixing FP4 (or Int8) and BetterTransformer
You can combine the different methods described above to get the best performance for your model. For example, you can use BetterTransformer with FP4 mixed-precision inference + flash attention:
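A sketch of how these can be combined, assuming bitsandbytes and optimum are installed (the checkpoint, prompt, and generation settings are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit quantization with fp16 compute, so Flash Attention kernels can be used
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=quantization_config,
    device_map="auto",
)

# Convert the quantized model to BetterTransformer
model = model.to_bettertransformer()

inputs = tokenizer("Hello, my dog is cute and", return_tensors="pt").to("cuda")

# Restrict SDPA to the Flash Attention kernel
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```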