Fully Sharded Data Parallel
Fully sharded data parallel (FSDP) is developed for distributed training of large pretrained models with up to 1T parameters. FSDP achieves this by sharding the model parameters, gradients, and optimizer states across data parallel processes, and it can also offload sharded model parameters to a CPU. The memory efficiency afforded by FSDP allows you to scale training to larger batch or model sizes.
Currently, FSDP does not confer any reduction in GPU memory usage and FSDP with CPU offload actually consumes 1.65x more GPU memory during training. You can track this PyTorch issue for any updates.
FSDP is supported in 🌍 Accelerate, and you can use it with 🌍 PEFT. This guide will help you learn how to use our FSDP training script. You’ll configure the script to train a large model for conditional generation.
Configuration
Begin by running the following command to create a FSDP configuration file with 🌍 Accelerate. Use the --config_file flag to save the configuration file to a specific location, otherwise it is saved as a default_config.yaml file in the 🌍 Accelerate cache. The configuration file is used to set the default options when you launch the training script.
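A minimal sketch of the command, assuming you want the file saved as fsdp_config.yaml (the name used later in this guide):

accelerate config --config_file fsdp_config.yaml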
You’ll be asked a few questions about your setup and which options to configure. For this example, make sure you fully shard the model parameters, gradients, and optimizer states, offload them to the CPU, and wrap model layers based on the Transformer layer class name.
For example, your FSDP configuration file may look like the following:
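The sketch below assumes a single machine with two GPUs; the exact keys and values depend on your 🌍 Accelerate version and the answers you gave:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1  # 1 corresponds to FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: T5Block
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
use_cpu: false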
The important parts
Let’s dig a bit deeper into the training script to understand how it works.
The main() function begins by initializing an Accelerator class, which handles everything for distributed training, such as automatically detecting your training environment.
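A minimal sketch of that first step (not the script verbatim):

from accelerate import Accelerator

def main():
    # Accelerator reads the FSDP settings from the configuration file created above
    accelerator = Accelerator()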
💡 Feel free to change the model and dataset inside the main function. If your dataset format is different from the one in the script, you may also need to write your own preprocessing function.
The script also creates a configuration corresponding to the 🌍 PEFT method you’re using. For LoRA, you’ll use LoraConfig to specify the task type and several other important parameters, such as the dimension of the low-rank matrices, their scaling factor, and the dropout probability of the LoRA layers. If you want to use a different 🌍 PEFT method, replace LoraConfig with the appropriate configuration class.
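For a conditional generation task, the configuration could look like the sketch below; the rank, scaling factor, and dropout values are illustrative:

from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # conditional generation
    inference_mode=False,
    r=8,               # dimension of the low-rank matrices
    lora_alpha=32,     # scaling factor for the low-rank matrices
    lora_dropout=0.1,  # dropout probability of the LoRA layers
)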
Next, the script wraps the base model and peft_config with the get_peft_model() function to create a PeftModel.
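A sketch of that step; the t5-base checkpoint is only an example, substitute the model you actually want to train:

from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")  # illustrative checkpoint
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # prints how few parameters LoRA actually trains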
Throughout the script, you’ll see the main_process_first and wait_for_everyone functions which help control and synchronize when processes are executed.
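For example, dataset preprocessing is typically guarded like this so the mapping only runs once and the other processes reuse the cached result (a sketch, assuming dataset and preprocess_function are defined elsewhere in the script):

with accelerator.main_process_first():
    # the main process runs the mapping first; the remaining processes load it from cache
    processed_datasets = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
accelerator.wait_for_everyone()  # block here until every process has caught up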
After your dataset is prepared and all the necessary training components are loaded, the script checks whether you’re using the fsdp_plugin. PyTorch offers two ways to wrap model layers in FSDP, automatically or manually. The simplest method is to let FSDP automatically and recursively wrap model layers without changing any other code; you can choose to wrap the model layers based on the layer name or on their size (number of parameters). The FSDP configuration file above uses the TRANSFORMER_BASED_WRAP option to wrap the T5Block layer.
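The check in the script looks roughly like the sketch below; it assumes the fsdp_auto_wrap_policy helper exposed by 🌍 PEFT in peft.utils.other:

from peft.utils.other import fsdp_auto_wrap_policy

# only adjust the wrapping policy when the run was actually launched with FSDP
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)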
Next, use 🌍 Accelerate’s prepare function to prepare the model, dataloaders, optimizer, and scheduler for training.
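A sketch of that call, assuming the dataloaders, optimizer, and learning rate scheduler were created earlier in the script:

model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
    model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
)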
From here, the remainder of the script handles the training loop, evaluation, and sharing your model to the Hub.
Train
Run the following command to launch the training script. Earlier, you saved the configuration file to fsdp_config.yaml, so you’ll need to pass the path to the launcher with the --config_file argument.
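A sketch of the launch command; the script filename is illustrative and should point at wherever you saved the FSDP training script:

accelerate launch --config_file fsdp_config.yaml peft_lora_seq2seq_accelerate_fsdp.py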
Once training is complete, the script returns the accuracy and compares the predictions to the labels.