Working with large models
Dispatching and Offloading Models
accelerate.init_empty_weights
( include_buffers: bool = None )
Parameters
include_buffers (bool, optional): Whether or not to also put all buffers on the meta device while initializing.
A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. Useful when just initializing the model would exceed the available RAM.
Example:
Any model created under this context manager has no weights. As such, you can't do something like model.to(some_device) with it. To load weights inside your empty model, see load_checkpoint_and_dispatch().
accelerate.cpu_offload
( model: Module, execution_device: typing.Optional[torch.device] = None, offload_buffers: bool = False, state_dict: typing.Union[typing.Dict[str, torch.Tensor], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module): The model to offload.
execution_device (torch.device, optional): The device on which the forward pass of the model will be executed (should be a GPU). Will default to the model's first parameter device.
offload_buffers (bool, optional, defaults to False): Whether or not to offload the buffers with the model parameters.
state_dict (Dict[str, torch.Tensor], optional): The state dict of the model that will be kept on CPU.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that state dict and put on the execution device passed as they are needed, then offloaded again.
accelerate.disk_offload
( model: Module, offload_dir: typing.Union[str, os.PathLike], execution_device: typing.Optional[torch.device] = None, offload_buffers: bool = False, preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module): The model to offload.
offload_dir (str or os.PathLike): The folder in which to offload the model weights (or where the model weights are already offloaded).
execution_device (torch.device, optional): The device on which the forward pass of the model will be executed (should be a GPU). Will default to the model's first parameter device.
offload_buffers (bool, optional, defaults to False): Whether or not to offload the buffers with the model parameters.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as memory-mapped arrays in a given folder. During the forward pass, parameters will be accessed from that folder and put on the execution device passed as they are needed, then offloaded again.
accelerate.dispatch_model
( model: Module, device_map: typing.Dict[str, typing.Union[int, str, torch.device]], main_device: typing.Optional[torch.device] = None, state_dict: typing.Union[typing.Dict[str, torch.Tensor], NoneType] = None, offload_dir: typing.Union[str, os.PathLike, NoneType] = None, offload_index: typing.Union[typing.Dict[str, str], NoneType] = None, offload_buffers: bool = False, skip_keys: typing.Union[str, typing.List[str], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None, force_hooks: bool = False )
Parameters
model (torch.nn.Module): The model to dispatch.
device_map (Dict[str, Union[str, int, torch.device]]): A dictionary mapping module names in the model's state_dict to the device they should go to. Note that "disk" is accepted even if it's not a proper value for torch.device.
main_device (str, int or torch.device, optional): The main execution device. Will default to the first device in the device_map different from "cpu" or "disk".
state_dict (Dict[str, torch.Tensor], optional): The state dict of the part of the model that will be kept on CPU.
offload_dir (str or os.PathLike): The folder in which to offload the model weights (or where the model weights are already offloaded).
offload_index (Dict, optional): A dictionary from weight name to their information (dtype/shape or safetensors filename). Will default to the index saved in save_folder.
offload_buffers (bool, optional, defaults to False): Whether or not to offload the buffers with the model parameters.
skip_keys (str or List[str], optional): A list of keys to ignore when moving inputs or outputs between devices.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
force_hooks (bool, optional, defaults to False): Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a single device.
Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on the CPU or even the disk.
accelerate.load_checkpoint_and_dispatch
( model: Module, checkpoint: typing.Union[str, os.PathLike], device_map: typing.Union[str, typing.Dict[str, typing.Union[int, str, torch.device]], NoneType] = None, max_memory: typing.Union[typing.Dict[typing.Union[int, str], typing.Union[int, str]], NoneType] = None, no_split_module_classes: typing.Optional[typing.List[str]] = None, offload_folder: typing.Union[str, os.PathLike, NoneType] = None, offload_buffers: bool = False, dtype: typing.Union[str, torch.dtype, NoneType] = None, offload_state_dict: typing.Optional[bool] = None, skip_keys: typing.Union[str, typing.List[str], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None, force_hooks: bool = False )
Parameters
model (torch.nn.Module): The model in which we want to load a checkpoint.
checkpoint (str or os.PathLike): The folder checkpoint to load. It can be: a path to a file containing a whole model state dict; a path to a .json file containing the index to a sharded checkpoint; or a path to a folder containing a unique .index.json file and the shards of a checkpoint.
device_map (Dict[str, Union[int, str, torch.device]], optional): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. To have Accelerate compute the most optimized device_map automatically, set device_map="auto". For more information about each option, see here.
max_memory (Dict, optional): A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset.
no_split_module_classes (List[str], optional): A list of layer class names that should never be split across devices (for instance any layer that has a residual connection).
offload_folder (str or os.PathLike, optional): If the device_map contains any value "disk", the folder where we will offload weights.
offload_buffers (bool, optional, defaults to False): In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as well as the parameters.
dtype (str or torch.dtype, optional): If provided, the weights will be converted to that type when loaded.
offload_state_dict (bool, optional): If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit. Will default to True if the device map picked contains "disk" values.
skip_keys (str or List[str], optional): A list of keys to ignore when moving inputs or outputs between devices.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
force_hooks (bool, optional, defaults to False): Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a single device.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded and adds the various hooks that will make this model run properly (even if split across devices).
Example:
accelerate.load_checkpoint_in_model
( model: Module, checkpoint: typing.Union[str, os.PathLike], device_map: typing.Union[typing.Dict[str, typing.Union[int, str, torch.device]], NoneType] = None, offload_folder: typing.Union[str, os.PathLike, NoneType] = None, dtype: typing.Union[str, torch.dtype, NoneType] = None, offload_state_dict: bool = False, offload_buffers: bool = False, keep_in_fp32_modules: typing.List[str] = None, offload_8bit_bnb: bool = False )
Parameters
model (torch.nn.Module): The model in which we want to load a checkpoint.
checkpoint (str or os.PathLike): The folder checkpoint to load. It can be: a path to a file containing a whole model state dict; a path to a .json file containing the index to a sharded checkpoint; a path to a folder containing a unique .index.json file and the shards of a checkpoint; or a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
device_map (Dict[str, Union[int, str, torch.device]], optional): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device.
offload_folder (str or os.PathLike, optional): If the device_map contains any value "disk", the folder where we will offload weights.
dtype (str or torch.dtype, optional): If provided, the weights will be converted to that type when loaded.
offload_state_dict (bool, optional, defaults to False): If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard does not fit.
offload_buffers (bool, optional, defaults to False): Whether or not to include the buffers in the weights offloaded to disk.
keep_in_fp32_modules (List[str], optional): A list of the modules that we keep in torch.float32 dtype.
offload_8bit_bnb (bool, optional): Whether or not to enable offload of 8-bit modules on CPU/disk.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded.
Once loaded across devices, you still need to call dispatch_model() on your model to make it able to run. To group the checkpoint loading and dispatch in one single call, use load_checkpoint_and_dispatch().
accelerate.infer_auto_device_map
( model: Module, max_memory: typing.Union[typing.Dict[typing.Union[int, str], typing.Union[int, str]], NoneType] = None, no_split_module_classes: typing.Optional[typing.List[str]] = None, dtype: typing.Union[str, torch.dtype, NoneType] = None, special_dtypes: typing.Union[typing.Dict[str, typing.Union[str, torch.dtype]], NoneType] = None, verbose: bool = False )
Parameters
model (torch.nn.Module): The model to analyze.
max_memory (Dict, optional): A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset.
no_split_module_classes (List[str], optional): A list of layer class names that should never be split across devices (for instance any layer that has a residual connection).
dtype (str or torch.dtype, optional): If provided, the weights will be converted to that type when loaded.
special_dtypes (Dict[str, Union[str, torch.dtype]], optional): If provided, special dtypes to consider for some specific weights (will override dtype used as default for all weights).
verbose (bool, optional, defaults to False): Whether or not to provide debugging statements as the function builds the device_map.
Computes a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, such that:
we don't exceed the memory available on any of the GPUs.
if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that has the largest size.
if offload to the CPU is needed, we don't exceed the RAM available on the CPU.
if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk that has the largest size.
All computation is done by analyzing the sizes and dtypes of the model parameters. As a result, the model can be on the meta device (as it would be if initialized within the init_empty_weights context manager).
Model Hooks
Hook Classes
class accelerate.hooks.ModelHook
( )
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference with PyTorch's existing hooks is that these hooks also receive the keyword arguments (kwargs) passed to the forward method.
Class attribute:
no_grad (bool, optional, defaults to False): Whether or not to execute the actual forward pass under the torch.no_grad() context manager.
detach_hook
( module )
Parameters
module (torch.nn.Module): The module detached from this hook.
To be executed when the hook is detached from a module.
init_hook
( module )
Parameters
module (torch.nn.Module): The module attached to this hook.
To be executed when the hook is attached to the module.
post_forward
( module, output ) -> Any
Parameters
module (torch.nn.Module): The module whose forward pass has been executed just before this event.
output (Any): The output of the module.
Returns
Any: The processed output.
To be executed just after the forward method of the model.
pre_forward
( module, *args, **kwargs ) -> Tuple[Tuple[Any], Dict[str, Any]]
Parameters
module (torch.nn.Module): The module whose forward pass will be executed just after this event.
args (Tuple[Any]): The positional arguments passed to the module.
kwargs (Dict[str, Any]): The keyword arguments passed to the module.
Returns
Tuple[Tuple[Any], Dict[str, Any]]: A tuple with the processed args and kwargs.
To be executed just before the forward method of the model.
class accelerate.hooks.AlignDevicesHook
( execution_device: typing.Union[int, str, torch.device, NoneType] = None, offload: bool = False, io_same_device: bool = False, weights_map: typing.Optional[typing.Mapping] = None, offload_buffers: bool = False, place_submodules: bool = False, skip_keys: typing.Union[str, typing.List[str], NoneType] = None )
Parameters
execution_device (torch.device, optional): The device on which inputs and model weights should be placed before the forward pass.
offload (bool, optional, defaults to False): Whether or not the weights should be offloaded after the forward pass.
io_same_device (bool, optional, defaults to False): Whether or not the output should be placed on the same device as the input was.
weights_map (Mapping[str, torch.Tensor], optional): When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
offload_buffers (bool, optional, defaults to False): Whether or not to include the associated module's buffers when offloading.
place_submodules (bool, optional, defaults to False): Whether to place the submodules on execution_device during the init_hook event.
A generic ModelHook that ensures inputs and model weights are on the same device for the forward pass of the associated module, potentially offloading the weights after the forward pass.
class accelerate.hooks.SequentialHook
( *hooks )
A hook that can contain several hooks and iterates through them at each event.
Adding Hooks
accelerate.hooks.add_hook_to_module
( module: Module, hook: ModelHook, append: bool = False ) -> torch.nn.Module
Parameters
module (torch.nn.Module): The module to attach a hook to.
hook (ModelHook): The hook to attach.
append (bool, optional, defaults to False): Whether the hook should be chained with an existing one (if module already contains a hook) or not.
Returns
torch.nn.Module: The same module, with the hook attached (the module is modified in place, so the result can be discarded).
Adds a hook to a given module. This will rewrite the forward method of the module to include the hook. To remove this behavior and restore the original forward method, use remove_hook_from_module.
By default, if the module already contains a hook, it is replaced with the new hook passed. To chain two hooks together, pass append=True, so that the current and new hooks are chained into an instance of the SequentialHook class.
accelerate.hooks.attach_execution_device_hook
( module: Module, execution_device: typing.Union[int, str, torch.device], skip_keys: typing.Union[str, typing.List[str], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
module (torch.nn.Module): The module where we want to attach the hooks.
execution_device (int, str or torch.device): The device on which inputs and model weights should be placed before the forward pass.
skip_keys (str or List[str], optional): A list of keys to ignore when moving inputs or outputs between devices.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
Recursively attaches AlignDevicesHook to all submodules of a given model to make sure they have the right execution device.
accelerate.hooks.attach_align_device_hook
( module: Module, execution_device: typing.Optional[torch.device] = None, offload: bool = False, weights_map: typing.Optional[typing.Mapping] = None, offload_buffers: bool = False, module_name: str = '', skip_keys: typing.Union[str, typing.List[str], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
module (torch.nn.Module): The module where we want to attach the hooks.
execution_device (torch.device, optional): The device on which inputs and model weights should be placed before the forward pass.
offload (bool, optional, defaults to False): Whether or not the weights should be offloaded after the forward pass.
weights_map (Mapping[str, torch.Tensor], optional): When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
offload_buffers (bool, optional, defaults to False): Whether or not to include the associated module's buffers when offloading.
module_name (str, optional, defaults to ""): The name of the module.
skip_keys (str or List[str], optional): A list of keys to ignore when moving inputs or outputs between devices.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
Recursively attaches AlignDevicesHook to all submodules of a given model that have direct parameters and/or buffers.
accelerate.hooks.attach_align_device_hook_on_blocks
( module: Module, execution_device: typing.Union[torch.device, typing.Dict[str, torch.device], NoneType] = None, offload: typing.Union[bool, typing.Dict[str, bool]] = False, weights_map: typing.Mapping = None, offload_buffers: bool = False, module_name: str = '', skip_keys: typing.Union[str, typing.List[str], NoneType] = None, preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
module (torch.nn.Module): The module where we want to attach the hooks.
execution_device (torch.device or Dict[str, torch.device], optional): The device on which inputs and model weights should be placed before the forward pass. It can be one device for the whole module, or a dictionary mapping module name to device.
offload (bool, optional, defaults to False): Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole module, or a dictionary mapping module name to boolean.
weights_map (Mapping[str, torch.Tensor], optional): When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
offload_buffers (bool, optional, defaults to False): Whether or not to include the associated module's buffers when offloading.
module_name (str, optional, defaults to ""): The name of the module.
skip_keys (str or List[str], optional): A list of keys to ignore when moving inputs or outputs between devices.
preload_module_classes (List[str], optional): A list of classes whose instances should load all their weights (even in the submodules) at the beginning of the forward. This should only be used for classes that have submodules which are registered but not called directly during the forward, for instance if a dense linear layer is registered, but at forward, dense.weight and dense.bias are used in some operations instead of calling dense directly.
Attaches AlignDevicesHook to all blocks of a given model as needed.
Removing Hooks
accelerate.hooks.remove_hook_from_module
( module: Module, recurse = False ) -> torch.nn.Module
Parameters
module (torch.nn.Module): The module to remove the hook from.
recurse (bool, optional): Whether to remove the hooks recursively.
Returns
torch.nn.Module: The same module, with the hook detached (the module is modified in place, so the result can be discarded).
Removes any hook attached to a module via add_hook_to_module.
accelerate.hooks.remove_hook_from_submodules
( module: Module )
Parameters
module (torch.nn.Module): The module on which to remove all hooks.
Recursively removes all hooks attached on the submodules of a given model.