OOM error with multi-GPU training of Llama2-70B using QLoRA

I am trying to train the Llama2-70B model using 4-bit QLoRA on an 8xA100 80GB instance. With a single GPU it runs comfortably - it uses < 50GB of VRAM with a batch size of 2. I am also setting gradient_accumulation_steps = 4.
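For reference, a rough back-of-envelope for why a single GPU fits (the per-parameter storage cost is an assumption for NF4; quantization constants, activations, and LoRA optimizer state add more on top):

```python
# Rough estimate of the 4-bit base-model footprint; bytes_per_param is the
# assumed NF4 storage cost, ignoring quantization constants and adapters.
params = 70e9                  # Llama2-70B parameter count
bytes_per_param = 0.5          # 4-bit weight storage
base_gb = params * bytes_per_param / 1024**3
print(f"quantized base weights: ~{base_gb:.0f} GB")  # ~33 GB
```

That leaves plenty of headroom on an 80GB card, which is consistent with the < 50GB I see in practice.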

But when I run it on 8 GPUs, it consistently OOMs without completing a single step, even with a per-device batch size of 1. How do I debug this?

Accelerate config: https://gist.github.com/amangup/3a5f80a541d8226ca2101389e8bf1805

Training script: https://gist.github.com/amangup/e49ca9fc042caa062eeb6c1355fcd6c9

I run the script with `accelerate launch llama70_qlora_multigpu.py`.

Update: I did a few more things to debug this:

  1. Ran the script on a 7B model, and the training completed. But the memory cost was 24-28GB per GPU, compared to < 20GB for single-GPU training (with the same batch size).

  2. Changed the precision from bf16 to fp16 (fp16 is the dtype defined in config.json for the Llama2 models), and surprisingly it completed one step before OOMing in step 2. I'm not sure why that would affect memory usage, since both are 2-byte formats.

  3. I read this explanation of DDP. It suggests that the extra memory usage (compared to single-GPU execution) is limited to the gradient buckets being reduced across GPUs, which shouldn't account for as much extra memory as I'm observing!

  4. I assumed that the memory overhead due to DDP is proportional to the degree of parallelism, so I ran it on only 4 GPUs. This time it failed on step 5 instead of step 2. So it sounds like the DDP overhead is pretty huge.
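To put a number on what the DDP overhead *should* be with QLoRA (the trainable-parameter count below is an assumption, not taken from my config - only the LoRA adapters have gradients to sync):

```python
# Expected per-GPU DDP gradient-sync overhead for a LoRA run.
# trainable_params is an assumed figure (~200M, e.g. high-rank adapters
# on a 70B model); grad_bytes assumes bf16 gradients.
trainable_params = 2e8
grad_bytes = 2
bucket_mb = 25                 # DDP's default reducer bucket size
overhead_gb = trainable_params * grad_bytes / 1024**3 + bucket_mb / 1024
print(f"expected per-GPU sync overhead: ~{overhead_gb:.2f} GB")
```

Under those assumptions the expected overhead is well under 1GB, nowhere near the 4-8GB per GPU I measured on the 7B run.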

I wonder if the size of the reducer bucket is tunable.
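It is tunable: PyTorch DDP exposes `bucket_cap_mb` (default 25MB), and Accelerate passes it through via `DistributedDataParallelKwargs`. A sketch of how I'd try shrinking it (the value 10 is an arbitrary guess, not something I've validated):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Shrink DDP's gradient-reduction buckets from the 25MB default,
# trading some all-reduce efficiency for lower peak memory.
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=10)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```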

So, I finally got this to work. I initially had to turn gradient checkpointing off to deal with an error I was seeing; it turns out that tweaking a DDP parameter allows gradient checkpointing to work under DDP. That brought the memory requirement down enough for training to succeed on all 8 GPUs.
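For anyone hitting the same thing, a sketch of the combination, assuming the DDP parameter in question is `find_unused_parameters` (the usual culprit when gradient checkpointing and DDP clash, producing the "Expected to mark a variable ready only once" error):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# With find_unused_parameters=True, gradient checkpointing's re-computation
# can mark the same parameter ready twice and crash the DDP reducer;
# setting it to False avoids that (and skips the unused-parameter scan).
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```

Gradient checkpointing is then re-enabled on the model as usual (e.g. `model.gradient_checkpointing_enable()` for a Transformers model).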