Fully Sharded Data Parallelism Utilities
class accelerate.FullyShardedDataParallelPlugin
( sharding_strategy: typing.Any = None, backward_prefetch: typing.Any = None, mixed_precision_policy: typing.Any = None, auto_wrap_policy: typing.Optional[typing.Callable] = None, cpu_offload: typing.Any = None, ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None, state_dict_type: typing.Any = None, state_dict_config: typing.Any = None, optim_state_dict_config: typing.Any = None, limit_all_gathers: bool = False, use_orig_params: bool = False, param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], None]] = None, sync_module_states: bool = True, forward_prefetch: bool = False, activation_checkpointing: bool = False )
This plugin is used to enable fully sharded data parallelism.
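A minimal sketch of constructing the plugin and handing it to an Accelerator. The state dict configs shown are illustrative choices, and the script is assumed to be launched in a distributed environment (e.g. via accelerate launch):

```python
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Illustrative configuration: gather full (unsharded) model and optimizer
# state dicts on every rank without offloading them to CPU.
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```

Any field left as None is filled in from the FSDP values set via accelerate config or the corresponding environment variables.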
get_module_class_from_name
( module, name )
Parameters
module (torch.nn.Module) - The module to get the class from.
name (str) - The name of the class.
Gets a class from a module by its name.
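A short usage sketch, assuming the function is exported from accelerate.utils; the toy model and the "Linear" lookup are hypothetical. This kind of name-to-class lookup is what lets an FSDP auto-wrap policy be specified by layer class name in a config:

```python
import torch.nn as nn

from accelerate.utils import get_module_class_from_name

# Hypothetical toy model containing a Linear submodule.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

# Walk the module tree and return the class whose name matches "Linear".
linear_cls = get_module_class_from_name(model, "Linear")
print(linear_cls)  # <class 'torch.nn.modules.linear.Linear'>
```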