Fine Tuning LLama 3.2 1B Quantized Memory Requirements

Hi All!
I’m trying to fine-tune a Llama 3.2 1B Instruct model that is quantized during loading, but for some reason the trainer errors out with:
“OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB. GPU 0 has a total capacity of 6.00 GiB of which 3.48 GiB is free.”
I’m not sure why training such a relatively small model would require 32 GB of VRAM.
Would really appreciate any help if possible, code attached below:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    num_labels=23,
    quantization_config=quantization_config,
)
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules = ["q_proj", "k_proj", "v_proj"]
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=1,              
    per_device_train_batch_size=1,   
    per_device_eval_batch_size=16,   
    warmup_steps=500,                
    weight_decay=0.01,               
    logging_dir='./logs',            
    logging_steps=10,
    evaluation_strategy="epoch",     
    gradient_accumulation_steps=4,  
)

trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=eval_dataset,
)

trainer.train()

With the given information, it’s hard to tell why so much memory is needed; nothing in the code looks obviously wrong. It would help if you could share the full code and the full error message.

One common source of excessive memory usage is the data itself. If you have very long sequences, the quadratic memory cost of attention can easily cause OOMs. You could check whether setting a low max_seq_length curbs the memory usage, for example as in the sketch below.
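A minimal sketch of truncating during tokenization; the "text" column name and the 512-token limit are assumptions on my part, adjust them to your data:

max_seq_length = 512  # assumed limit; pick what fits your task and GPU

def tokenize_fn(examples):
    # Truncate (and pad) every example to max_seq_length tokens so no
    # single long sequence blows up the attention memory.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_seq_length,
        padding="max_length",
    )

train_dataset = train_dataset.map(tokenize_fn, batched=True)
eval_dataset = eval_dataset.map(tokenize_fn, batched=True)

If even short sequences OOM, it would also be worth printing the tokenized lengths of a few samples to confirm the truncation is actually applied before the data reaches the Trainer.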
