SFTTrainer takes up so much RAM that it breaks an A100 GPU

Hello folks,

I have been trying to fine-tune Llama 3 with a VeRA adapter on a fairly small dataset, "mlabonne/guanaco-llama2-1k". I put my training configuration into an SFTConfig and initialized an SFTTrainer object as my trainer. This is where the code blows up! It never even reaches the trainer.train() stage, because merely constructing the SFTTrainer object consumes so much RAM (system RAM, not GPU memory) that the session runs out of memory on both A100 and T4 runtimes.
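
To give an idea of the setup, here is a stripped-down sketch of what the notebook does (the model checkpoint, VeRA settings, and training arguments are simplified placeholders here; exact argument names can differ between TRL/PEFT versions, and the real values are in the notebook linked below):

```python
import torch
from datasets import load_dataset
from peft import VeraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Placeholder checkpoint id; the notebook uses a Llama 3 checkpoint.
model_id = "meta-llama/Meta-Llama-3-8B"

dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train")

# 4-bit quantization (one of the configurations I tried).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# VeRA adapter config (rank and target modules are simplified here;
# VeRA shares its frozen projections across adapted layers, so I kept
# the target modules to layers with matching shapes).
vera_config = VeraConfig(
    r=256,
    target_modules=["q_proj"],
    task_type="CAUSAL_LM",
)

training_args = SFTConfig(
    output_dir="./results",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    max_seq_length=512,
    dataset_text_field="text",
)

# This is the line that blows up: constructing the trainer already
# exhausts system RAM, before trainer.train() is ever reached.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=vera_config,
    tokenizer=tokenizer,
)
```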

You can find my Google Colab notebook here. I have tried

  • lots of different combinations of data types,
  • datasets with different sizes,
  • different quantization and attention implementation configs,
  • using TrainingArguments instead of SFTConfig,
  • different VeRA configs

but nothing worked; a rough sketch of some of these variations is at the end of this post. I would be more than glad if you could take a look and give me a glimmer of hope. Thanks!
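
For reference, here is a rough, self-contained sketch of a couple of the variations listed above (again with placeholder values; the TrainingArguments route passes the SFT-specific fields to the trainer directly, which the TRL version I am on still accepts, though exact kwargs may differ across versions):

```python
import torch
from datasets import load_dataset
from peft import VeraConfig
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

model_id = "meta-llama/Meta-Llama-3-8B"  # placeholder, as above
dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train")
vera_config = VeraConfig(r=256, target_modules=["q_proj"], task_type="CAUSAL_LM")

# Variation: a different dtype / attention implementation instead of 4-bit quantization.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,      # also tried bfloat16
    attn_implementation="sdpa",     # also tried "eager" and "flash_attention_2"
    device_map="auto",
)

# Variation: plain TrainingArguments instead of SFTConfig; the SFT-specific
# fields then go to the trainer directly.
args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=vera_config,
    dataset_text_field="text",
    max_seq_length=512,
)
```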