How to perform inference on large models with small resources

Handling big models for inference

One of the biggest advancements 🤗 Accelerate provides is big model inference, which lets you run inference on models that cannot fully fit on your graphics card.

This tutorial is split into two parts, showing how to use both 🤗 Accelerate and 🤗 Transformers (a higher-level API) to make use of this idea.

Using 🤗 Accelerate

For these tutorials, we’ll assume a typical workflow for loading your model, such as:

import torch

my_model = ModelClass(...)
state_dict = torch.load(checkpoint_file)
my_model.load_state_dict(state_dict)

Note that here we assume that ModelClass is a model that takes up more video-card memory than what can fit on your device (be it mps or cuda).

The first step is to initialize an empty skeleton of the model, which won’t take up any RAM, using the init_empty_weights() context manager:

from accelerate import init_empty_weights
with init_empty_weights():
    my_model = ModelClass(...)

With this, my_model is currently “parameterless”, leaving a much smaller memory footprint than you would get by loading the model onto the CPU directly.

Next we need to load in the weights to our model so we can perform inference.

For this we will use load_checkpoint_and_dispatch(), which as the name implies will load a checkpoint inside your empty model and dispatch the weights for each layer across all the devices you have available (GPU/MPS and CPU RAM).

To determine how this dispatch can be performed, specifying device_map="auto" is generally good enough: 🤗 Accelerate will first fill all the space on your GPU(s), then place the remaining weights in CPU RAM, and finally, if there is not enough RAM, offload them to disk (the absolute slowest option).
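The fill order described above can be pictured with a small, library-free sketch. This is not the allocation algorithm the library actually uses; the layer names, sizes, and capacities are made-up numbers, and `assign_devices` is a hypothetical helper written only for illustration.

```python
def assign_devices(layer_sizes, capacities):
    """Greedily place layers on devices in priority order.

    layer_sizes: dict of layer name -> size in GB (insertion order = model order)
    capacities:  dict of device name -> free space in GB, fastest device first
    Layers that fit on none of the listed devices fall back to "disk".
    """
    remaining = dict(capacities)
    devices = list(remaining)
    current = 0  # index of the device currently being filled
    device_map = {}
    for name, size in layer_sizes.items():
        # Move on to the next (slower) device once the current one is full.
        while current < len(devices) and remaining[devices[current]] < size:
            current += 1
        if current < len(devices):
            device = devices[current]
            remaining[device] -= size
        else:
            device = "disk"
        device_map[name] = device
    return device_map


# Hypothetical model: four layers of 4 GB each.
layers = {"embed": 4, "block.0": 4, "block.1": 4, "head": 4}
# Hypothetical hardware: one 8 GB GPU (device 0) and 4 GB of free CPU RAM.
device_map = assign_devices(layers, {0: 8, "cpu": 4})
print(device_map)
# {'embed': 0, 'block.0': 0, 'block.1': 'cpu', 'head': 'disk'}
```

In this toy setup the first two layers land on the GPU, the third on the CPU, and the last one spills to disk, which is the slowest place for it to live.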

For more details on designing your own device map, see this section of the concept guide.

See an example below:

from accelerate import load_checkpoint_and_dispatch

# Pass in the empty model created under init_empty_weights() above.
model = load_checkpoint_and_dispatch(
    my_model, checkpoint=checkpoint_file, device_map="auto"
)

If there are certain “chunks” of layers that shouldn’t be split across devices, you can pass their class names in as no_split_module_classes. Read more about it here.

Also, to save on memory (for example if the state_dict will not fit in RAM), a model’s weights can be split into multiple checkpoint files. Read more about it here.
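As a rough illustration of what splitting a checkpoint means, the sketch below groups weights into shards that each stay under a size limit and records which shard holds which weight. The weight names, sizes, file names, and the `shard_weights` helper are all hypothetical; real sharded checkpoints should be produced with the library's own saving utilities.

```python
def shard_weights(weight_sizes, max_shard_size):
    """Group weights into shards no larger than max_shard_size.

    Returns (shards, weight_map): a list of lists of weight names, and a
    dict mapping each weight name to the file name of its shard.
    A single weight larger than the limit gets a shard of its own.
    """
    shards, current, current_size = [], [], 0
    for name, size in weight_sizes.items():
        # Start a new shard when adding this weight would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)

    weight_map = {
        name: f"shard-{i + 1:05d}-of-{len(shards):05d}.bin"
        for i, shard in enumerate(shards)
        for name in shard
    }
    return shards, weight_map


# Hypothetical weights with sizes in MB.
sizes = {
    "embed.weight": 300,
    "block.0.weight": 200,
    "block.1.weight": 200,
    "head.weight": 300,
}
shards, weight_map = shard_weights(sizes, max_shard_size=500)
print(shards)
# [['embed.weight', 'block.0.weight'], ['block.1.weight', 'head.weight']]
print(weight_map["head.weight"])
# shard-00002-of-00002.bin
```

With a mapping like this, a loader only needs to hold one shard in RAM at a time instead of the whole state_dict.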

Now that the model is dispatched fully, you can perform inference as normal with the model:

input = torch.randn(2,3)
input = input.to("cuda")
output = model(input)

What happens now is that each time the input reaches a layer, that layer’s weights are sent from the CPU to the GPU (or from disk to CPU to GPU), the output is calculated, and the layer is then pulled back off the GPU before moving on down the line. While this adds some overhead to inference, it makes it possible to run a model of any size on your system, as long as the largest layer fits on your GPU.
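The point that only the largest layer has to fit on the GPU can be checked with a toy simulation. Nothing below touches real hardware or the library itself: the layer sizes are invented, activations are ignored, and `run_offloaded` is a hypothetical helper that only tracks how much weight memory would be resident at each step.

```python
def run_offloaded(layer_sizes, gpu_capacity):
    """Simulate streaming layers through the GPU one at a time.

    Returns the peak GPU memory used by weights. Raises MemoryError if
    any single layer is larger than the GPU.
    """
    peak = 0
    for name, size in layer_sizes.items():
        if size > gpu_capacity:
            raise MemoryError(f"layer {name!r} ({size} GB) does not fit on the GPU")
        # Only this layer's weights are on the GPU while its output is
        # computed; they are offloaded again before the next layer loads.
        peak = max(peak, size)
    return peak


# Hypothetical 30 GB model made of four layers, run on a 12 GB GPU.
layers = {"embed": 6, "block.0": 10, "block.1": 10, "head": 4}
print(sum(layers.values()))                    # 30 (total model size in GB)
print(run_offloaded(layers, gpu_capacity=12))  # 10 (peak weight memory on the GPU)
```

The 30 GB model runs on the 12 GB device because the peak is set by the largest single layer, at the cost of moving every layer on and off the GPU for each forward pass.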

Multiple GPUs can be utilized; however, this is considered “model parallelism”, and as a result only one GPU will be active at any given moment, waiting for the prior one to send it the output. You should launch your script normally with python; there is no need for torchrun, accelerate launch, etc.

For a visual representation of this, check out the animation below:

Complete Example

Below is the full example showcasing what we performed above:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

with init_empty_weights():
    model = ModelClass(...)

model = load_checkpoint_and_dispatch(
    model, checkpoint=checkpoint_file, device_map="auto"
)

input = torch.randn(2,3)
input = input.to("cuda")
output = model(input)

Using 🤗 Transformers, 🤗 Diffusers, and other 🤗 Open Source Libraries

Libraries that support 🤗 Accelerate big model inference include all of the earlier logic in their from_pretrained constructors.

These work by specifying a string representing the model to download from the 🤗 Hub and then passing device_map="auto" along with a few extra parameters.

As a brief example, we will look at using transformers to load BigScience’s T0pp model.

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")

After the model is loaded, all of the earlier preparation steps have been done for you and the model is ready to make use of all the resources in your machine. Through these constructors, you can also save more memory by specifying the precision the model is loaded in, through the torch_dtype parameter, such as:

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto", torch_dtype=torch.float16)
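A quick back-of-the-envelope calculation shows why the precision matters. The sketch assumes a model of roughly 11 billion parameters (a round figure chosen for illustration) and counts only the weights, ignoring activations and any framework overhead; `weight_memory_gb` is a hypothetical helper, not part of any library.

```python
# Bytes needed to store one parameter in each precision.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params, dtype):
    """Return the memory taken by the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


num_params = 11_000_000_000  # assumed parameter count, for illustration only
for dtype in ("float32", "float16"):
    print(dtype, weight_memory_gb(num_params, dtype), "GB")
# float32 44.0 GB
# float16 22.0 GB
```

Loading in half precision halves the space the weights need, which means fewer layers end up offloaded to CPU RAM or disk.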

To learn more about this, check out the 🤗 Transformers documentation available here.

Where to go from here

For a much more detailed look at big model inference, be sure to check out the Conceptual Guide on it.
