FSDP accelerate.prepare gives OOM. How to load model into single GPU, then distribute shards?

Hi.
I am trying to train a model with FSDP, and currently getting OOM. This is the situation/setup:

  • 4 A100-40GB GPUs.
  • Model size: 35GB (it fits into a single GPU, taking 88% of memory).

I have noticed that when I run the code, each of the 4 GPUs loads the full model into memory, so there are 4 replicas of the model, one per GPU. The error is then thrown in accelerator.prepare.

My understanding was that with FSDP I could split the model into shards and thus avoid having the entire model on a single GPU.
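As a rough sanity check of that expectation, here is a back-of-the-envelope sketch of the parameter memory per GPU, using only the numbers from the setup above. It assumes an even split across the 4 GPUs and ignores gradients, optimizer state, activations, and any layers that are temporarily gathered during forward/backward, so the real footprint will be higher.

```python
# Back-of-the-envelope parameter memory, using the numbers from the setup above.
# Assumption: parameters are split evenly across GPUs; gradients, optimizer
# state and activations are NOT included.
GPU_MEMORY_GB = 40.0   # A100-40GB
MODEL_SIZE_GB = 35.0   # full model
NUM_GPUS = 4


def shard_size_gb(model_gb: float, num_gpus: int) -> float:
    """Parameter memory per GPU if the model is split evenly."""
    return model_gb / num_gpus


# What I see now: every GPU holds a full replica.
replicated_fraction = MODEL_SIZE_GB / GPU_MEMORY_GB

# What I expected from full sharding: every GPU holds one quarter.
shard_gb = shard_size_gb(MODEL_SIZE_GB, NUM_GPUS)
sharded_fraction = shard_gb / GPU_MEMORY_GB

print(f"replicated: {MODEL_SIZE_GB:.2f} GB per GPU ({replicated_fraction:.1%} of memory)")
print(f"sharded:    {shard_gb:.2f} GB per GPU ({sharded_fraction:.1%} of memory)")
```

So a full replica leaves almost no headroom on each GPU, while an evenly sharded model would leave most of the memory free for everything else.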

So the question is:
How can I load the model into a single GPU, and then distribute it among the rest of the GPUs?

Thanks a lot in advance!


Launch script with command:

accelerate launch --config_file config/fsdp_config.yaml --mixed_precision="bf16" --num_processes=$NUM_GPUS train_uvit_accelerate.py \
  --epochs=$EPOCHS \
  --batch_size=$BATCH_SIZE \
  --img_size=$IMAGE_SIZE

Code snippet:

    accelerator = Accelerator(log_with="wandb")
    accelerator.init_trackers(
        project_name=PROJECT_NAME,
        config=config,
        init_kwargs={"wandb": {"tags": ["sd", config.model_name]}},
    )
    device = accelerator.device
    
    # model setup
    model = UViT(**config.model_params).to(device)
    model = accelerator.prepare(model)

Accelerate config YAML file:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: Transformer, UViT
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

@muellerzr Could you shed some light on the matter? Maybe I am not understanding some concept correctly.
Thanks!

This probably happens because you are moving the full model to each GPU with `.to(device)` before calling `accelerator.prepare`. Could you try removing that `.to(device)` call?