Memory consumption of QLoRA with gradient checkpointing

Hi!
I've used QLoRA with gradient checkpointing on Llama-2-7b, and I'm surprised by how much VRAM a single forward pass on a 2577-token input takes. Before the forward pass, the process was using only 4.8 GB, and during the forward pass it crashed with an out-of-memory (OOM) error on an A100 80GB!

I enabled gradient checkpointing as follows:

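from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(wd, load_in_4bit=True)  # wd: path to the Llama-2-7b weights
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model.gradient_checkpointing_enable()  # redundant: the call above already enables checkpointing
peft_config = LoraConfig(r=8, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, peft_config)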

It shouldn’t use that much VRAM with gradient checkpointing, should it?
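For scale, here is a rough back-of-envelope sketch of activation memory for this input. It rests on assumptions, not measurements: the Llama-2-7b shape (32 layers, 32 attention heads, hidden size 4096), batch size 1, an eager attention implementation that materializes a (heads, seq, seq) score tensor in every layer, and fp32 activations (4 bytes per element). It ignores the softmax output, the MLP intermediates and the logits, so real usage would be higher.

```python
# Rough, illustrative activation-memory estimate for one forward pass.
# ASSUMPTIONS (not measured): Llama-2-7b shape (32 layers, 32 heads,
# hidden size 4096), batch size 1, eager attention that materializes a
# (heads, seq, seq) score tensor per layer, fp32 activations (4 bytes).

GIB = 1024 ** 3


def attention_scores_bytes(seq_len, n_heads=32, bytes_per_elem=4):
    """Size of one layer's attention score tensor (batch size 1)."""
    return n_heads * seq_len * seq_len * bytes_per_elem


def checkpointed_inputs_bytes(seq_len, n_layers=32, hidden=4096, bytes_per_elem=4):
    """Approximate memory kept with per-layer checkpointing:
    only each decoder layer's input hidden state is saved."""
    return n_layers * seq_len * hidden * bytes_per_elem


seq = 2577
per_layer_scores = attention_scores_bytes(seq)
print(f"attention scores, one layer: {per_layer_scores / GIB:.2f} GiB")
print(f"one score tensor kept per layer (no checkpointing): {32 * per_layer_scores / GIB:.1f} GiB")
print(f"checkpointed layer inputs: {checkpointed_inputs_bytes(seq) / GIB:.2f} GiB")
```

Under these assumptions, keeping even one score tensor per layer (as happens when checkpointing is not in effect) costs about 25 GiB, whereas with checkpointing active only about 1.3 GiB of layer inputs, plus one layer's activations during recomputation, should persist.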

Thanks,
Christophe