Checkpointing
When training a PyTorch model with 🤗 Accelerate, you may often want to save the current state of training and resume it later. Doing so requires saving and loading the model, optimizer, RNG generators, and the GradScaler. Inside 🤗 Accelerate are two convenience functions to achieve this quickly:
Use save_state() for saving everything mentioned above to a folder location
Use load_state() for loading everything stored from an earlier save_state()
To further customize where and how states are saved through save_state(), the ProjectConfiguration class can be used. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint will be located at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
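As a brief sketch (the project directory below is a placeholder), this behavior can be enabled by passing a ProjectConfiguration to the Accelerator:

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# With automatic_checkpoint_naming enabled, each save_state() call writes to
# project_dir/checkpoints/checkpoint_{checkpoint_number} automatically
config = ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True)
accelerator = Accelerator(project_config=config)
```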
Note that these states are expected to come from the same training script; they should not be from two separate scripts.
By using register_for_checkpointing(), you can register custom objects to be automatically stored or loaded by the two prior functions, so long as the object has state_dict and load_state_dict functionality. This could include objects such as a learning rate scheduler.
Below is a brief example using checkpointing to save and reload a state during training:
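The snippet below is a minimal sketch; the model, optimizer, dataloader, loss function, and save path are toy placeholders chosen to keep it self-contained.

```python
import torch
from accelerate import Accelerator

# Toy placeholders -- substitute your own model, optimizer, data, and loss
my_model = torch.nn.Linear(10, 2)
my_optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
my_training_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
my_loss_function = torch.nn.CrossEntropyLoss()

accelerator = Accelerator()

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(
    my_model, my_optimizer, my_training_dataloader
)

# Register the LR scheduler so save_state()/load_state() handle it as well
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state(output_dir="my/save/path")

# Perform training
for epoch in range(2):
    for inputs, targets in my_training_dataloader:
        my_optimizer.zero_grad()
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the state saved above (model, optimizer, RNG states, and the
# registered scheduler are all reloaded)
accelerator.load_state("my/save/path")
```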
Restoring the state of the DataLoader
After resuming from a checkpoint, it may also be desirable to resume from a particular point in the active DataLoader if the state was saved in the middle of an epoch. You can use skip_first_batches() to do so.
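The snippet below is a minimal sketch; it assumes a state was previously saved to the placeholder path my/save/path and that the checkpoint was taken 100 batches into the epoch.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Toy placeholder dataloader -- substitute your own
dataset = torch.utils.data.TensorDataset(torch.randn(1024, 10))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
train_dataloader = accelerator.prepare(train_dataloader)

accelerator.load_state("my/save/path")

# Assume the checkpoint was saved 100 batches into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# Resumed (first) epoch: iterate over the remaining batches only
for batch in skipped_dataloader:
    ...

# Later epochs: go back to the full dataloader
for batch in train_dataloader:
    ...
```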