Saving and loading training states
When training a PyTorch model with 🤗 Accelerate, you may often want to save the current state of training and resume from it later. Doing so requires saving and loading the model, optimizer, RNG generators, and the GradScaler. Inside 🤗 Accelerate are two convenience functions to achieve this quickly:
Use Accelerator.save_state() to save everything mentioned above to a folder location
Use Accelerator.load_state() to load everything stored from an earlier save_state, as sketched below
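As a quick illustration of the two calls (the save path shown is arbitrary):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# Save the model, optimizer, RNG, and GradScaler states to a folder
accelerator.save_state("my/save/path")

# Later, restore everything from that folder
accelerator.load_state("my/save/path")
```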
To further customize where and how states are saved through save_state(), the ProjectConfiguration class can be used. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint will be located at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
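A sketch of enabling automatic checkpoint naming (the project directory shown is illustrative):

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Checkpoints will be written to my/save/path/checkpoints/checkpoint_{checkpoint_number}
config = ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True)
accelerator = Accelerator(project_config=config)

accelerator.save_state()  # first call writes my/save/path/checkpoints/checkpoint_0
```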
Note that these states are expected to come from the same training script; they should not be from two separate scripts.
By using Accelerator.register_for_checkpointing(), you can register custom objects to be automatically stored or loaded by the two prior functions, so long as the object has state_dict and load_state_dict functionality. This could include objects such as a learning rate scheduler.
Below is a brief example using checkpointing to save and reload a state during training:
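The sketch below uses a toy model and dataset purely for illustration; all names and the save path are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Toy model, data, optimizer, and scheduler for illustration
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
dataloader = DataLoader(dataset, batch_size=8)
loss_fn = torch.nn.CrossEntropyLoss()

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Register the LR scheduler so it is included in the saved state
accelerator.register_for_checkpointing(scheduler)

# Save the starting state
accelerator.save_state("my/save/path")

for epoch in range(2):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()

# Restore the previously saved state
accelerator.load_state("my/save/path")
```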
After resuming from a checkpoint, it may also be desirable to resume from a particular point in the active DataLoader if the state was saved during the middle of an epoch. You can use Accelerator.skip_first_batches() to do so.
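A sketch of that pattern, assuming the state was saved a few batches into the epoch (the dataloader, batch count, and save path are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=8))

accelerator.load_state("my/save/path")

# Assume the state was saved 2 batches into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=2)

# Finish the interrupted epoch with the skipped dataloader ...
for batch in skipped_dataloader:
    ...

# ... then iterate over the full dataloader for subsequent epochs
for batch in train_dataloader:
    ...
```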