Kwargs Handlers
The following objects can be passed to the main Accelerator through its kwargs_handlers argument to customize how certain PyTorch objects related to distributed training or mixed precision are created.
AutocastKwargs
class accelerate.AutocastKwargs
( enabled: bool = True, cache_enabled: bool = None )
Use this object in your Accelerator to customize how torch.autocast behaves. Please refer to the documentation of this context manager for more information on each argument.
DistributedDataParallelKwargs
class accelerate.DistributedDataParallelKwargs
( dim: int = 0, broadcast_buffers: bool = True, bucket_cap_mb: int = 25, find_unused_parameters: bool = False, check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False )
Use this object in your Accelerator to customize how your model is wrapped in a torch.nn.parallel.DistributedDataParallel. Please refer to the documentation of this wrapper for more information on each argument.
gradient_as_bucket_view is only available in PyTorch 1.7.0 and later versions.
static_graph is only available in PyTorch 1.11.0 and later versions.
FP8RecipeKwargs
class accelerate.utils.FP8RecipeKwargs
( margin: int = 0, interval: int = 1, fp8_format: str = 'E4M3', amax_history_len: int = 1, amax_compute_algo: str = 'most_recent', override_linear_precision: typing.Tuple[bool, bool, bool] = (False, False, False) )
Use this object in your Accelerator to customize the initialization of the recipe for FP8 mixed precision training. Please refer to the documentation of Transformer Engine's DelayedScaling recipe for more information on each argument.
GradScalerKwargs
class accelerate.GradScalerKwargs
( init_scale: float = 65536.0, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True )
Use this object in your Accelerator to customize the behavior of mixed precision, specifically how the torch.cuda.amp.GradScaler is created. Please refer to the documentation of this scaler for more information on each argument.
GradScaler is only available in PyTorch 1.5.0 and later versions.
InitProcessGroupKwargs
class accelerate.InitProcessGroupKwargs
( backend: typing.Optional[str] = 'nccl', init_method: typing.Optional[str] = None, timeout: timedelta = datetime.timedelta(seconds=1800) )
Use this object in your Accelerator to customize the initialization of the distributed processes. Please refer to the documentation of torch.distributed.init_process_group for more information on each argument.