Mistral-7B-v0.1 fine-tuning runs out of memory after a few iterations

I am trying to LoRA-finetune mistralai/Mistral-7B-v0.1 on the C4 dataset using an NVIDIA A40 GPU with 40GB of memory. PyTorch raises an Out-of-Memory error, but only after 3 successful training iterations, even though nothing else runs in between (no additional evaluation). I would expect memory usage to stay roughly constant from one iteration to the next. Am I doing something wrong, or is this a memory leak?

My setup:
batch size 1, AdamW, gradient_checkpointing enabled, and gradient accumulation steps of 1.
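One way to check whether memory really stays constant is to log allocated and peak GPU memory after every step. The sketch below is a hypothetical helper, not code from this thread: it uses a toy `nn.Linear` model in place of the LoRA-wrapped Mistral model, and the names `train_with_memory_log` and `grows_after_warmup` are made up for illustration. It skips the first step when checking for growth, because AdamW allocates its two state buffers per parameter lazily on the first `optimizer.step()`, so a jump after step 1 is expected and is not a leak.

```python
import torch
from torch import nn


def train_with_memory_log(model, batches, optimizer, device):
    """Train over `batches`, returning (allocated, peak) bytes per step.

    Hypothetical helper for illustration. On a CPU-only machine the CUDA
    counters are unavailable, so (0, 0) is recorded for every step.
    """
    use_cuda = device.type == "cuda"
    log = []
    model.train()
    for inputs, targets in batches:
        if use_cuda:
            # Measure the peak of this step only, not of the whole run.
            torch.cuda.reset_peak_memory_stats(device)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Release this step's tensors before measuring what is still held.
        del inputs, targets, loss
        if use_cuda:
            log.append((torch.cuda.memory_allocated(device),
                        torch.cuda.max_memory_allocated(device)))
        else:
            log.append((0, 0))
    return log


def grows_after_warmup(samples, warmup=1, tolerance=0):
    """True if any step after warm-up holds more than `tolerance` bytes
    beyond the previous step (a possible leak)."""
    steady = samples[warmup:]
    return any(later - earlier > tolerance
               for earlier, later in zip(steady, steady[1:]))


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(16, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batches = [(torch.randn(1, 16), torch.randn(1, 1)) for _ in range(5)]
    log = train_with_memory_log(model, batches, optimizer, device)
    print(log)
    print("growth after warm-up:", grows_after_warmup([a for a, _ in log]))
```

If the allocated column is flat after step 1 but the peak column varies, the OOM is more likely caused by an unusually large step than by memory accumulating across steps.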


Could you share your code so we can better understand the problem?

I suspect you may be keeping every batch on the GPU, so memory builds up over a few iterations. Check your dataloader and training loop, and make sure references to used batches are released instead of being kept around from step to step.
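As an illustration of that advice, here is a minimal sketch of a loop that keeps no per-step GPU tensors alive. It assumes a generic PyTorch loop with dict batches whose keys `input` and `target` are hypothetical, and a toy model standing in for the LoRA model; it is not the asker's code. The commented-out lines show the kind of caching that makes memory grow with every step.

```python
import torch
from torch import nn


def run_epoch(model, loader, optimizer, device):
    """Train for one pass without retaining per-step tensors on the GPU."""
    losses = []  # Python floats only, never tensors
    # cached = []  # anti-pattern: a list that collects GPU tensors
    model.train()
    for batch in loader:
        inputs = batch["input"].to(device)
        targets = batch["target"].to(device)
        optimizer.zero_grad(set_to_none=True)
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # .item() copies the scalar to the host; appending `loss` itself
        # would keep a GPU tensor carrying autograd history alive.
        losses.append(loss.item())
        # cached.append(inputs)  # would pin every batch in GPU memory
    return losses


if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(8, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loader = [{"input": torch.randn(1, 8), "target": torch.randn(1, 1)}
              for _ in range(4)]
    print(run_epoch(model, loader, optimizer, device))
```

Because `inputs`, `targets`, and `loss` are rebound on every iteration and nothing else references them, the previous step's tensors can be freed before the next forward pass.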