Fine-tuning Meta-Llama-3.1-8B: OOM error after the 1st training step

I tried to fine-tune Llama 3.1 8B on a single GPU with 48 GB of memory (an RTX A6000). Training runs for one step and computes the loss, but raises a CUDA out-of-memory error on the 2nd step.

What could be the possible reason?

Environment:

- `transformers` version: 4.44.2
- Platform: Linux-5.15.0-73-generic-x86_64-with-glibc2.31
- Python version: 3.11.5
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.4
- Accelerate version: 0.34.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA RTX A6000

Code snippet:

import logging

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
)

# tokz (the tokenizer), llama_model_path, train_ds and val_ds are defined earlier and omitted here.

def print_gpu_memory(step=""):
    print(f"\n--- GPU Memory Usage at {step} ---")
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    # Note: "Free" here is total minus allocated; it ignores memory reserved by the
    # caching allocator and memory used by other processes.
    print(f"Free: {torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated() / 1e9:.2f} GB")

# Define the quantization config
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

base_model = LlamaForCausalLM.from_pretrained(
    llama_model_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print_gpu_memory("after base model")


base_model.resize_token_embeddings(len(tokz))


lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"]
)


model = get_peft_model(base_model, lora_config)


print_gpu_memory("after peft")


output_dir = "xxx"


training_args = TrainingArguments(
    optim="sgd",
    learning_rate=0.01,
    output_dir=output_dir,
    overwrite_output_dir=True,
    logging_dir=f"{output_dir}/logs",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    num_train_epochs=10,
    gradient_accumulation_steps=4,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    bf16=True,
    gradient_checkpointing=True,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokz,
    mlm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
)


print_gpu_memory("After init trainer")

try:
    for epoch in range(training_args.num_train_epochs):
        print(f"\nStarting epoch {epoch + 1}")
        print_gpu_memory(f"Start of epoch {epoch + 1}")

        for step, batch in enumerate(trainer.get_train_dataloader()):
            print(f"Step {step + 1}")
            print_gpu_memory(f"Before step {step + 1}")

            loss = trainer.training_step(model, batch)

            print(f"Loss: {loss.item()}")
            print_gpu_memory(f"After step {step + 1}")

            if (step + 1) % 10 == 0:  # Print every 10 steps
                print(f"Completed {step + 1} steps")
                print_gpu_memory(f"After {step + 1} steps")

            if (step + 1) % 100 == 0:
                torch.cuda.empty_cache()

    print_gpu_memory(f"End of epoch {epoch + 1}")

except RuntimeError as e:
    if "out of memory" in str(e):
        print(f"WARNING: out of memory")
        logging.exception(e)
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
    else:
        raise e

print_gpu_memory("After training")

Error log:

Starting epoch 1                                                                                                                                                                                              
                                                                                                                                                                                                              
--- GPU Memory Usage at Start of epoch 1 ---                                                                                                                                                                  
Allocated: 10.16 GB                                                                                                                                                                                           
Reserved: 10.34 GB                                                                                                                                                                                            
Free: 40.88 GB                                                                                                                                                                                                
Step 1                                                                                                                                                                                                        
                                                                                                                                                                                                              
--- GPU Memory Usage at Before step 1 ---                                                                                                                                                                     
Allocated: 10.16 GB                                                                                                                                                                                           
Reserved: 10.34 GB                                                                                                                                                                                            
Free: 40.88 GB                                                                                                                                                                                                
/home/yunding/anaconda3/envs/mapgpt-dp/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")                                                                                                           
Loss: 1.0950329303741455                                                                                                                                                                                      
                                                                                                                                                                                                              
--- GPU Memory Usage at After step 1 ---                                                                                                                                                                      
Allocated: 9.16 GB                                                                                                                                                                                            
Reserved: 24.75 GB                                                                                                                                                                                            
Free: 41.88 GB                                                                                                                                                                                                
Step 2                                                                                                                                                                                                        
                                                                                                                                                                                                              
--- GPU Memory Usage at Before step 2 ---                                                                                                                                                                     
Allocated: 9.16 GB                                                                                                                                                                                            
Reserved: 24.75 GB                                                                                                                                                                                            
Free: 41.88 GB                                                                                                                                                                                                
WARNING: out of memory                                                                                                                                                                                        
ERROR:root:CUDA out of memory. Tried to allocate 100.00 MiB. GPU 0 has a total capacity of 47.53 GiB of which 40.25 MiB is free. Including non-PyTorch memory, this process has 47.32 GiB memory in use. Of the allocated memory 45.17 GiB is allocated by PyTorch, and 1.81 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
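
The error message itself suggests setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. If that is worth trying, my understanding is that it has to be set before the first CUDA allocation, roughly like this (a sketch; I have not verified that it resolves the OOM):

import os

# Must be set before the model is loaded / before any CUDA tensor is created,
# otherwise the allocator config has no effect on the current process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"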