Fully Sharded Data Parallelism Utilities

Utilities for Fully Sharded Data Parallelism

class accelerate.FullyShardedDataParallelPlugin


( sharding_strategy: typing.Any = None, backward_prefetch: typing.Any = None, mixed_precision_policy: typing.Any = None, auto_wrap_policy: typing.Optional[typing.Callable] = None, cpu_offload: typing.Any = None, ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None, state_dict_type: typing.Any = None, state_dict_config: typing.Any = None, optim_state_dict_config: typing.Any = None, limit_all_gathers: bool = False, use_orig_params: bool = False, param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], None]] = None, sync_module_states: bool = True, forward_prefetch: bool = False, activation_checkpointing: bool = False )

This plugin is used to enable fully sharded data parallelism. It configures PyTorch's FSDP wrapper and is passed to the Accelerator, which applies it to your model during prepare.
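As a minimal sketch (assuming the script is launched in a distributed environment, e.g. via `accelerate launch`), the plugin can be constructed directly and handed to the `Accelerator`; the FSDP enums come from `torch.distributed.fsdp`, and the specific settings below are illustrative choices, not defaults:

```python
import torch
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, ShardingStrategy

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Illustrative configuration: fully shard parameters, gradients, and
# optimizer state; prefetch all-gathers during backward; keep params on GPU.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=CPUOffload(offload_params=False),
    sync_module_states=True,  # broadcast rank 0's weights to the other ranks at init
    use_orig_params=True,     # expose original parameters, e.g. for optimizer param groups
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# `prepare` wraps the model in FSDP according to the plugin's settings.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)
```

The same settings can also come from `accelerate config` / environment variables; constructing the plugin in code simply takes precedence for the options you set explicitly.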

get_module_class_from_name


( module, name )

Parameters

  • module (torch.nn.Module) — The module to get the class from.

  • name (str) — The name of the class.

Gets a class from a module by its name: recursively searches the module and its children for a submodule whose class name matches name, and returns that class.
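A hedged usage sketch: this utility is handy for resolving a transformer block class by name so it can feed an FSDP auto-wrap policy. The model choice and the class name "GPT2Block" are illustrative assumptions (they require the `transformers` package), and the function is assumed importable from `accelerate.utils`:

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM  # illustrative dependency

from accelerate.utils import get_module_class_from_name

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Resolve the transformer layer class from its name; "GPT2Block" is the
# block class used by this particular (illustrative) model.
block_class = get_module_class_from_name(model, "GPT2Block")

# Build an auto-wrap policy that puts each block in its own FSDP unit;
# the result can be passed as `auto_wrap_policy` to the plugin above.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={block_class},
)
```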
