Multi-GPU Training using SFTTrainer

How do I use the YAML config file with my code below? I want to use SFTTrainer directly rather than running everything through scripts. Is there a way to train my model on multiple GPUs — for example, by passing the YAML file to the trainer so that it automatically trains on multiple GPUs?

from datasets import load_dataset

# Load the first 30% of the hub dataset's "train" split.
# With `split=...`, load_dataset returns a single `Dataset`, not a DatasetDict,
# so `dataset["train"]` / `dataset["test"]` (used when building the trainer)
# would be treated as *column* lookups and fail. `train_test_split` returns a
# DatasetDict with real "train" and "test" splits.
# NOTE(review): 10% held out for evaluation is an arbitrary choice — adjust.
dataset = load_dataset("Alok2304/Indian_Law_Final_Dataset", split="train[:30%]")
dataset = dataset.train_test_split(test_size=0.1, seed=42)
dataset

import torch
from transformers import BitsAndBytesConfig

# QLoRA-style quantization settings: 4-bit NF4 weights, double quantization
# enabled, and bfloat16 as the dtype used for the de-quantized computation.
_bnb_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
bnb_config = BitsAndBytesConfig(**_bnb_kwargs)

import os

from transformers import AutoModelForCausalLM, AutoTokenizer

# Under `accelerate launch` / `torchrun`, one copy of this script runs per GPU
# and the LOCAL_RANK environment variable says which GPU this copy owns
# (defaults to 0 for a plain single-process `python train.py`).
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# Place the whole quantized model on this process's GPU at load time, so DDP
# gets one full replica per device. The previous `.to(device)` referenced an
# undefined name, and quantized (bitsandbytes) models should be placed with
# `device_map` rather than moved with `.to()` afterwards.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    quantization_config=bnb_config,
    device_map={"": local_rank},
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import setup_chat_format

# Add the chat template and its special tokens while `model` is still a plain
# transformers model. setup_chat_format resizes the token embeddings; doing it
# first means the resized model is what gets prepared and wrapped by PEFT,
# instead of resizing underneath an existing PeftModel.
model, tokenizer = setup_chat_format(model, tokenizer)

# Prepare the k-bit quantized model for adapter training.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    # Makes get_peft_model build the causal-LM flavour of PeftModel.
    task_type="CAUSAL_LM",
)
# NOTE(review): the embedding rows added by setup_chat_format are not LoRA
# targets, so they stay untrained. If the new chat special tokens must be
# learned, consider modules_to_save=["embed_tokens", "lm_head"] (large memory
# cost for a 3B model) — confirm whether this is needed.
model = get_peft_model(model, lora_config)

from trl import SFTConfig, SFTTrainer

args = SFTConfig(
    output_dir="lora_model/",
    # Batch sizes are per GPU: effective train batch is
    # per_device_train_batch_size * gradient_accumulation_steps * num_gpus.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-05,
    gradient_accumulation_steps=2,
    max_steps=300,
    logging_strategy="steps",
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    eval_strategy="steps",
    eval_steps=25,
    # The 4-bit compute dtype in bnb_config is bfloat16; use matching bf16
    # mixed precision instead of fp16 (requires bf16-capable GPUs).
    bf16=True,
    data_seed=42,
    max_seq_length=2048,
    gradient_checkpointing=True,
    # Reentrant checkpointing breaks DDP with frozen base weights + LoRA;
    # the non-reentrant variant is the multi-GPU-safe choice.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # A PEFT-wrapped model is not a PreTrainedModel, so Trainer would default
    # DDP to find_unused_parameters=True, which conflicts with checkpointing.
    ddp_find_unused_parameters=False,
    report_to="none",
)

# Multi-GPU: SFTTrainer does not read the Accelerate YAML itself. Save this file
# as e.g. train.py and start it with
#     accelerate launch --config_file your_config.yaml train.py
# so the launcher spawns one process per GPU and the Trainer runs DDP.
# NOTE(review): from a notebook, accelerate's `notebook_launcher` is the
# alternative, but all model loading must then happen inside the launched
# function — confirm against the Accelerate docs.
# NOTE(review): SFTTrainer expects a text/messages-style dataset; verify the
# columns of the dataset match what your TRL version expects.
trainer = SFTTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

# Building the trainer alone does not train anything.
trainer.train()
1 Like