Why does disabling gradient checkpointing reduce throughput with FSDP?

I have been fine-tuning a Llama-style model on 8×40GB A100 GPUs with FlashAttention and FSDP. While experimenting with combinations of training arguments, I found that turning *off* gradient checkpointing actually reduced training throughput, which surprised me: gradient checkpointing adds re-computation, so I expected disabling it to make training faster, not slower. I suspect FSDP is involved, since for smaller models turning gradient checkpointing off does improve throughput. What could cause this?
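To make the trade-off I'm asking about concrete, here is a minimal, self-contained sketch in plain PyTorch (not my actual training script, and no FSDP involved) showing that checkpointing discards activations and re-runs the forward pass during backward:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Count forward executions to show the re-computation that
# gradient checkpointing performs during the backward pass.
calls = {"fwd": 0}

w = torch.randn(16, 16, requires_grad=True)
x = torch.randn(4, 16, requires_grad=True)

def block(inp):
    calls["fwd"] += 1
    return torch.relu(inp @ w)

# Without checkpointing: the forward runs once; its activations
# are kept in memory for the backward pass.
calls["fwd"] = 0
block(x).sum().backward()
print(calls["fwd"])  # 1

# With checkpointing: activations are not stored, so the block is
# re-run during backward, trading extra compute for less memory.
calls["fwd"] = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
print(calls["fwd"])  # 2
```

Naively, that extra forward pass should only ever cost throughput, which is why the FSDP result confused me.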