FSDP OOM issue and comparison to DeepSpeed

I am trying to fine-tune the EleutherAI/gpt-j-6b model on my dataset. I am using the run_clm.py script from the transformers library for fine-tuning the model.

My system configuration is

- `Accelerate` version: 0.20.3
- Platform: Linux-4.19.0-22-cloud-amd64-x86_64-with-glibc2.28
- Python version: 3.9.16
- Numpy version: 1.25.0
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- PyTorch XPU available: False
- System RAM: 669.27 GB
- GPU type: NVIDIA A100-SXM4-40GB
- Transformers version: 4.31.0.dev0

The command I use to run is:

```bash
torchrun --nproc_per_node=8 run_clm.py \
  --model_name_or_path EleutherAI/gpt-j-6b \
  --per_device_train_batch_size 1 \
  --output_dir /tmp/tmpiii \
  --overwrite_output_dir \
  --bf16 \
  --do_train \
  --max_train_samples 64 \
  --num_train_epochs 1 \
  --dataset_name "wikitext" \
  --dataset_config "wikitext-2-v1" \
  --save_strategy "no" \
  --report_to none \
  --block_size [BLOCK_SIZE] \
  --gradient_checkpointing True \
  --fsdp "full_shard auto_wrap" \
  --fsdp_transformer_layer_cls_to_wrap "GPTJBlock"
```

where `BLOCK_SIZE` is 1024 or 2048.

The above script works with a block size of 1024 but gets an OOM error with 2048. I can get it to work using DeepSpeed.

I have 2 questions:

- Why am I getting OOM with FSDP and not with DeepSpeed? Is it because of absent CPU offloading in FSDP + mixed precision? (See the sketch after this list for what those FSDP knobs refer to.)
- What is the difference in performance when using FSDP vs DeepSpeed? Will the answer depend on block_size, model_size, and even the underlying GPU?
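
For reference, the FSDP knobs in question (CPU offload and bf16 mixed precision) look roughly like this in the raw PyTorch FSDP API, outside the Trainer. This is a minimal sketch with a toy model, not code from this run, and it assumes launching under torchrun:

```python
# Sketch: FSDP CPU offload + bf16 mixed precision via the raw PyTorch API.
# The tiny Sequential model is a stand-in, not GPT-J.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

fsdp_model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),  # park sharded params in CPU RAM
    mixed_precision=MixedPrecision(               # bf16 compute, analogous to --bf16
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)
```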

That seems probable, as CPU offloading reduces VRAM usage significantly. The other reason is that FSDP checkpointing is different and is currently not supported by the --gradient_checkpointing argument. As mentioned in the issue, please refer to this notebook for how to do it: https://github.com/lessw2020/transformer_central/tree/main/activation_checkpointing_tutorial.
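
The manual approach from that tutorial looks roughly like this for the GPT-J setup above. This is a minimal sketch: `fsdp_model` is assumed to be the already-FSDP-wrapped model, which is not shown in this thread.

```python
# Apply FSDP activation (gradient) checkpointing manually, following the linked
# tutorial, instead of relying on --gradient_checkpointing.
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.gptj.modeling_gptj import GPTJBlock

non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Wrap every GPTJBlock inside the already-FSDP-wrapped model.
apply_activation_checkpointing(
    fsdp_model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda submodule: isinstance(submodule, GPTJBlock),
)
```

The idea is that checkpointing is applied per GPTJBlock on the FSDP-wrapped model, rather than through the --gradient_checkpointing flag.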

Ideally, they should perform similarly. I haven't done such benchmarking yet.

> FSDP checkpointing is different and is currently not supported by the --gradient_checkpointing argument. As mentioned in the issue, please refer to this notebook for how to do it.

Ah, I see. I thought the Trainer would internally handle the FSDP checkpointing. Thanks for letting me know!

Is this implying that having `gradient_checkpointing=True` was the issue, and that it can be fixed by changing it to `False`?