I am trying to LoRA-finetune mistralai/Mistral-7B-v0.1 on the C4 dataset using an NVIDIA A40 GPU with 40GB of memory. The problem is that PyTorch raises an out-of-memory error, but only after 3 successful training iterations, even though nothing else happens in between - no additional evaluation. I would expect the memory usage to stay roughly constant across iterations. Am I doing something wrong, or is this a potential memory leak?
Regarding my setup:
I am using a batch size of 1, the AdamW optimizer, gradient checkpointing, and gradient accumulation steps of 1.
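Here is a minimal sketch of what my training loop looks like (simplified; the dataloader/tokenization is a placeholder, and the LoRA hyperparameters shown are not the exact ones I use). I also log allocated CUDA memory per step to confirm whether it actually grows:

```python
# Sketch of the setup described above, not a full repro.
# Assumes transformers + peft are installed and a CUDA GPU is available.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model.gradient_checkpointing_enable()
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=2e-4
)

for step, batch in enumerate(dataloader):  # dataloader: C4, batch size 1 (placeholder)
    out = model(**batch)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # accumulation steps = 1
    # Log allocated memory per step to see whether it really grows:
    print(step, torch.cuda.memory_allocated() // 2**20, "MiB")
```

One thing I have already checked is that I am not accumulating live graph references (e.g. appending `loss` instead of `loss.item()` to a running list), which I understand is a common cause of memory growing for a few iterations before an OOM.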