OOM error for LoRA multi-GPU fine-tuning

I'm getting an OOM error when fine-tuning phi4-mini-reasoning with LoRA on 3x V100 16GB, even with a per-device batch size of 1 and a sequence length of 1024 tokens. I tried both DDP and FSDP and both run out of memory. On a single GPU the same setup passes with a batch size of 2 and peak memory under 15 GB. Does anyone know what would make the multi-GPU run fail? My training script is below; a small per-rank memory probe I'm thinking of adding is sketched after it.

#!/usr/bin/env python
import os, torch, argparse
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    TrainingArguments, Trainer
)
from peft import (
    LoraConfig, prepare_model_for_kbit_training, get_peft_model
)
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import DistributedType
from functools import partial
import torch.distributed as dist  # Added for explicit init

MODEL_ID  = "microsoft/Phi-4-mini-reasoning"
DATA_PATH = "./flattened_distilled_dataset-win1k-phi"
OUT_DIR   = "./lora_phi4"

def main():
    # ───────────────────────────── Accelerator info ──
    local_rank  = int(os.environ.get("LOCAL_RANK", 0))
    world_size  = int(os.environ.get("WORLD_SIZE", 1))
    global_rank = int(os.environ.get("RANK", local_rank))  # == local_rank on a single node
    device      = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(local_rank)

    # Explicitly initialize the distributed process group (safer)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=global_rank,   # use the global rank, not just the local one
    )
    
    # ───────────────────────────── Tokenizer ─────────
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

    # ───────────────────────────── 4-bit model ───────
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit             = True,
        bnb_4bit_quant_type      = "nf4",
        bnb_4bit_compute_dtype   = torch.float16,
        bnb_4bit_use_double_quant= True,
    )

    # each process loads on *its* GPU
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config = bnb_cfg,
        device_map          = {"": device},     # crucial
        trust_remote_code   = True,
        use_cache           = False,
    )
    base_model = prepare_model_for_kbit_training(base_model)

    # ───────────────────────────── LoRA ──────────────
    lora_cfg = LoraConfig(
        r=32,  # Reduced from 64 to save memory
        lora_alpha=128,
        lora_dropout=0.08,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
            "up_proj", "down_proj", "gate_proj",
            "fc1", "fc2", "dense"],
    )
    model = get_peft_model(base_model, lora_cfg)

    # Best-effort: enable xformers memory-efficient attention if available
    # (silently skipped if the model doesn't expose this hook)
    try:
        import xformers, xformers.ops
        model._set_memory_efficient_attention_xformers(True)
    except Exception:
        pass

    model.print_trainable_parameters()
    
    # Print GPU memory usage after model load (for debugging; printed on every rank)
    print(f"Rank {local_rank}: GPU memory allocated after model load: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    print(f"Rank {local_rank}: GPU memory reserved after model load: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")

    # ───────────────────────────── Dataset ───────────
    # Load the full dataset, then shard per process to save RAM
    # (NOTE: Trainer/Accelerate also splits batches across ranks, so this manual
    #  shard may be redundant and may mean each rank only sees part of the data)
    full_ds = load_from_disk(DATA_PATH)
    train_ds = full_ds.shard(num_shards=world_size, index=local_rank)
    if local_rank == 0:
        print(f"Sharded dataset: each process has {len(train_ds)} examples (total: {len(full_ds)})")

    def collate_fn(batch):
        # Examples are pre-tokenized to a fixed 1024-token window, so plain stacking works
        ids  = [torch.tensor(x["input_ids"]) for x in batch]
        mask = [torch.tensor(x["attention_mask"]) for x in batch]
        return {
            "input_ids"     : torch.stack(ids),
            "attention_mask": torch.stack(mask),
            "labels"        : torch.stack(ids),
        }

    # ───────────────────────────── TrainingArguments ─
    args = TrainingArguments(
        output_dir                 = OUT_DIR,
        num_train_epochs           = 2,
        per_device_train_batch_size= 1,      # per-GPU!
        gradient_accumulation_steps= 2,      # effective batch = 1 * 2 * world_size (= 6 on 3 GPUs)
        gradient_checkpointing     = True,
        learning_rate              = 3e-4,
        warmup_ratio               = 0.03,
        logging_steps              = 10,
        save_steps                 = 500,
        save_strategy              = "steps",
        fp16                       = True,
        optim                      = "paged_adamw_8bit",  # Changed to paged BNB optimizer for memory efficiency
        report_to                  = "none",
        remove_unused_columns      = False,
        ddp_find_unused_parameters = False,  # speeds up DDP
        ddp_bucket_cap_mb          = 25,     # Balanced value
    )

    trainer = Trainer(
        model         = model,
        args          = args,
        train_dataset = train_ds,
        tokenizer     = tok,
        data_collator = collate_fn,
    )
    
    # Print GPU memory after Trainer init, before training (printed on every rank)
    print(f"Rank {local_rank}: GPU memory allocated after Trainer init: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    print(f"Rank {local_rank}: GPU memory reserved after Trainer init: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")

    trainer.train()
    # Only rank-0 writes to disk
    if local_rank == 0:
        trainer.save_model(OUT_DIR)
        tok.save_pretrained(OUT_DIR)

if __name__ == "__main__":
    main()
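
In case it helps with debugging, this is the small per-rank memory probe I'm thinking of dropping in right after dist.init_process_group() to check how much free memory each rank actually starts with before the model loads. It's just a sketch, not part of the script above, and only uses torch.cuda.mem_get_info plus the same LOCAL_RANK convention:

import os, torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# free/total bytes on this process's GPU, as reported by the CUDA driver
free_b, total_b = torch.cuda.mem_get_info()
print(f"Rank {local_rank} -> cuda:{local_rank}: "
      f"{free_b / 1e9:.2f} GB free of {total_b / 1e9:.2f} GB total")

If the free number is already well under 16 GB on some rank before the model loads, then something else (another process, or another rank's CUDA context landing on the same GPU) is taking the memory.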