I get an OOM error when fine-tuning Phi-4-mini-reasoning with a per-device batch size of 1 and a sequence length of 1024 tokens on 3x V100 16GB. I tried both DDP and FSDP and both run out of memory, yet on a single GPU the same setup trains fine with a batch size of 2 and peak memory under 15 GB. Does anyone know what would make the multi-GPU run work? My training script is below.
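(For the multi-GPU runs I launch it on one node roughly like this; the exact flags are approximate and the script name is just a placeholder: torchrun --nproc_per_node=3 finetune_phi4_lora.py)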
#!/usr/bin/env python
import os, torch, argparse
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    TrainingArguments, Trainer
)
from peft import (
    LoraConfig, prepare_model_for_kbit_training, get_peft_model
)
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import DistributedType
from functools import partial
import torch.distributed as dist # Added for explicit init
MODEL_ID = "microsoft/Phi-4-mini-reasoning"
DATA_PATH = "./flattened_distilled_dataset-win1k-phi"
OUT_DIR = "./lora_phi4"
def main():
    # ----------------------------- Accelerator info --
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(local_rank)
    # Explicitly initialize distributed process group (safer)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=local_rank
    )
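    # When launched with torchrun, HF Trainer / Accelerate should detect this
    # already-initialized process group and reuse it instead of creating a new one.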
    # ----------------------------- Tokenizer ---------
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    # ----------------------------- 4-bit model -------
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit              = True,
        bnb_4bit_quant_type       = "nf4",
        bnb_4bit_compute_dtype    = torch.float16,
        bnb_4bit_use_double_quant = True,
    )
    # each process loads the model on *its own* GPU
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config = bnb_cfg,
        device_map          = {"": device},  # crucial: whole model on this rank's GPU
        trust_remote_code   = True,
        use_cache           = False,
    )
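    # prepare_model_for_kbit_training freezes the quantized base weights, upcasts the
    # remaining small fp16 layers (norms, etc.) to fp32 for stability, and sets the
    # model up for gradient checkpointing.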
    base_model = prepare_model_for_kbit_training(base_model)
    # ----------------------------- LoRA --------------
    lora_cfg = LoraConfig(
        r=32,                 # Reduced from 64 to save memory
        lora_alpha=128,
        lora_dropout=0.08,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "up_proj", "down_proj", "gate_proj",
                        "fc1", "fc2", "dense"],
    )
    model = get_peft_model(base_model, lora_cfg)
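    # Opportunistic xformers hook: transformers causal-LM models don't expose this
    # private method, so the call below usually raises and is swallowed by the
    # except, leaving attention on the default implementation.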
    try:
        import xformers, xformers.ops
        model._set_memory_efficient_attention_xformers(True)
    except Exception:
        pass
    model.print_trainable_parameters()
    # Print GPU memory usage after model load (for debugging; each rank reports its own device)
    print(f"Rank {local_rank}: GPU memory allocated after model load: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    print(f"Rank {local_rank}: GPU memory reserved after model load: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")
    # ----------------------------- Dataset -----------
    # Load full dataset, then shard per process to save RAM
    full_ds = load_from_disk(DATA_PATH)
    train_ds = full_ds.shard(num_shards=world_size, index=local_rank)
    if local_rank == 0:
        print(f"Sharded dataset: each process has {len(train_ds)} examples (total: {len(full_ds)})")
    def collate_fn(batch):
        ids  = [torch.tensor(x["input_ids"]) for x in batch]
        mask = [torch.tensor(x["attention_mask"]) for x in batch]
        return {
            "input_ids"     : torch.stack(ids),
            "attention_mask": torch.stack(mask),
            "labels"        : torch.stack(ids),
        }
    # ----------------------------- TrainingArguments -
    args = TrainingArguments(
        output_dir                  = OUT_DIR,
        num_train_epochs            = 2,
        per_device_train_batch_size = 1,    # per-GPU!
        gradient_accumulation_steps = 2,    # effective batch = 1 x 2 x world_size
        gradient_checkpointing      = True,
        learning_rate               = 3e-4,
        warmup_ratio                = 0.03,
        logging_steps               = 10,
        save_steps                  = 500,
        save_strategy               = "steps",
        fp16                        = True,
        optim                       = "paged_adamw_8bit",  # Changed to paged BNB optimizer for memory efficiency
        report_to                   = "none",
        remove_unused_columns       = False,
        ddp_find_unused_parameters  = False,  # speeds up DDP
        ddp_bucket_cap_mb           = 25,     # Balanced value
    )
    trainer = Trainer(
        model         = model,
        args          = args,
        train_dataset = train_ds,
        tokenizer     = tok,
        data_collator = collate_fn,
    )
    # Print GPU memory after Trainer init (before training; each rank reports its own device)
    print(f"Rank {local_rank}: GPU memory allocated after Trainer init: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")
    print(f"Rank {local_rank}: GPU memory reserved after Trainer init: {torch.cuda.memory_reserved(device) / 1e9:.2f} GB")
    trainer.train()
    # Only rank-0 writes to disk
    if local_rank == 0:
        trainer.save_model(OUT_DIR)
        tok.save_pretrained(OUT_DIR)
if __name__ == "__main__":
    main()