I set up multi-GPU training in a Kaggle notebook by changing a few settings (e.g. device_map) and launching with accelerate's notebook_launcher. However, I get an OOM error when fine-tuning a 4-bit quantized Llama-3-8B on 2x T4 GPUs. For 4-bit quantized fine-tuning of an 8B-parameter model I would expect a single 16 GB GPU to be sufficient, so 2 GPUs with distributed training should not OOM. What seems strange is that GPU usage was already 11.5 GB (screenshot attached) on each of the 2 GPUs right after the model checkpoint shards were loaded, and the trainer failed soon after.
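For reference, here is the rough back-of-envelope I had in mind (my own assumptions, not measured numbers, and the ~42M LoRA parameter count is my own estimate for r=16 over all seven projection modules), which is why ~11.5 GB per GPU right after loading surprised me:

# Back-of-envelope QLoRA memory estimate (assumptions only, not measurements)
GIB = 1024**3
n_params = 8e9                          # Llama-3-8B base weights
lora_params = 42e6                      # ~42M LoRA params for r=16 on all 7 proj modules (my count)

base_4bit = n_params * 0.5 / GIB        # NF4 -> ~0.5 bytes/param, ~3.7 GiB
lora_fp16 = lora_params * 2 / GIB       # adapter weights in fp16
adam_8bit = lora_params * 2 * 1 / GIB   # two 8-bit optimizer states per trainable param
grads_fp16 = lora_params * 2 / GIB      # gradients only exist for the LoRA params

print(f"weights ~{base_4bit:.1f} GiB, LoRA ~{lora_fp16*1024:.0f} MiB, "
      f"optimizer ~{adam_8bit*1024:.0f} MiB, grads ~{grads_fp16*1024:.0f} MiB")
# -> roughly 4-5 GiB of model-related memory before activations and the CUDA
#    context, so ~11.5 GiB per GPU right after loading looks off to me.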
Notebook function code:
def main():
    import os
    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import ORPOConfig, ORPOTrainer, setup_chat_format
    from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
    from accelerate import Accelerator

    accelerator = Accelerator()
    # Pin the whole (quantized) model to the GPU owned by this process (DDP-style).
    device_map = {"": accelerator.process_index}

    # QLoRA config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
    )

    base_model = "meta-llama/Meta-Llama-3-8B"
    new_model = "Llama-3-8B_FT_ORPO"

    # HF_TOKEN is defined earlier in the notebook (Kaggle secret).
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=HF_TOKEN)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        # device_map="auto",
        device_map=device_map,
        token=HF_TOKEN,
        attn_implementation="eager"
    )
    model, tokenizer = setup_chat_format(model, tokenizer)
    model = prepare_model_for_kbit_training(model)

    dataset_name = "mlabonne/orpo-dpo-mix-40k"
    dataset = load_dataset(dataset_name, split="all")
    dataset = dataset.shuffle(seed=42).select(range(30))  # Only use 30 samples for test

    def format_chat_template(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    dataset = dataset.map(
        format_chat_template,
        num_proc=os.cpu_count(),
    )
    dataset = dataset.train_test_split(test_size=0.1)

    # torch.cuda.empty_cache()
    orpo_args = ORPOConfig(
        learning_rate=8e-6,
        lr_scheduler_type="linear",
        max_length=1024,
        max_prompt_length=512,
        beta=0.1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,
        optim="paged_adamw_8bit",
        num_train_epochs=1,
        evaluation_strategy="steps",
        eval_steps=0.2,
        logging_steps=1,
        warmup_steps=10,
        report_to="none",
        output_dir="./results/",
        remove_unused_columns=False,
        fp16=True,
        bf16=False,
        ddp_find_unused_parameters=False,
        gradient_checkpointing=True,
        # gradient_checkpointing_kwargs={"use_reentrant": False},  # must be False for DDP
    )

    trainer = ORPOTrainer(
        model=model,
        args=orpo_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        peft_config=peft_config,
        tokenizer=tokenizer,
    )

    print(f'n_gpu: {orpo_args.n_gpu}; Mode: {orpo_args.parallel_mode}')
    print(f'Num Processes: {accelerator.num_processes}; Device: {accelerator.device}; Process Index: {accelerator.process_index}')
    print(f'Accel Type: {accelerator.distributed_type}')

    trainer.train()
    trainer.save_model(new_model)


from accelerate import notebook_launcher
notebook_launcher(main, num_processes=2)
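Side note: to capture the memory numbers programmatically rather than from the Kaggle screenshot, I plan to drop in a small helper like this (hypothetical, not in the notebook yet) right after from_pretrained and again just before trainer.train():

import torch

def report_gpu_memory(tag, accelerator):
    # PyTorch's view of memory on the GPU this process is pinned to
    # (nvidia-smi will show a bit more because of the CUDA context).
    allocated = torch.cuda.memory_allocated(accelerator.device) / 1024**3
    reserved = torch.cuda.memory_reserved(accelerator.device) / 1024**3
    print(f"[rank {accelerator.process_index}] {tag}: "
          f"allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")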
After the OOM error, I tried to see whether FSDP could be used instead by adding the following two arguments to the ORPOConfig, but that resulted in an AttributeError:
fsdp="full_shard",
fsdp_config={'min_num_params': 2000, 'offload_params': False, 'sharding_strategy': 1},
Error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
fn(i, *args)
File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/launch.py", line 626, in __call__
self.launcher(*args)
File "/tmp/ipykernel_34/1091546508.py", line 122, in main
trainer.train()
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2001, in _inner_training_loop
self._fsdp_qlora_plugin_updates()
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 4425, in _fsdp_qlora_plugin_updates
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
File "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py", line 396, in fsdp_auto_wrap_policy
transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
AttributeError: type object 'FullyShardedDataParallelPlugin' has no attribute 'get_module_class_from_name'
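In case a version mismatch is relevant here, this is the snippet I will use to report the exact library versions involved (I can post the output if useful):

import importlib.metadata as md

# Print installed versions of the libraries in the stack above.
for pkg in ("transformers", "accelerate", "peft", "trl", "bitsandbytes", "torch"):
    try:
        print(pkg, md.version(pkg))
    except md.PackageNotFoundError:
        print(pkg, "not installed")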
Questions:
- How much GPU memory should be enough to fine-tune a 4-bit quantized 8B model in a single-GPU or multi-GPU environment?
- How can I force accelerate to use FSDP when running in a notebook environment where no accelerate config file is used? (See the untested sketch after this list for the kind of thing I mean.)
- Is the error above caused by an incorrect argument in my code, or is something else missing?
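To make the second question concrete, this is the kind of setup I am hoping is possible without a config file (untested sketch; the ACCELERATE_USE_FSDP / FSDP_* environment variable names and values are what I understood from the accelerate FSDP docs, so please correct me if they are wrong):

import os
from accelerate import notebook_launcher

# Untested sketch: ask accelerate to build its FSDP plugin from environment
# variables instead of a config file, then launch the same main() as before.
os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD"
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "LlamaDecoderLayer"
os.environ["FSDP_OFFLOAD_PARAMS"] = "false"

notebook_launcher(main, num_processes=2)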
@muellerzr - Appreciate your thoughts!