I am fine-tuning a long-context model and want to optimize GPU usage for single-GPU long-context training. Currently I can get training to run at a tokenized sequence length of 8192 by juggling a few parameters. Ideally I would like to double or even quadruple that length, since the context windows of the Gemma3 models are at least 32K. Doubling also seems plausible on paper: GPU memory usage at length=8192 is around 40GB, almost exactly half of one 80GB A100. However, when I set length=16384, I get a CUDA OOM error. What are some avenues I can explore to reduce GPU memory usage, beyond the two obvious ones: (1) more GPUs and (2) quantizing the model? Here is my current training script:
from datasets import load_dataset
from trl import RewardTrainer, RewardConfig
from peft import LoraConfig, TaskType
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
torch.set_default_device('cuda')
model = AutoModelForCausalLM.from_pretrained("gemma3", attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained("gemma3")
train_dataset = load_dataset("json", data_files="training_data.json", split="train")
tokenizer.pad_token = tokenizer.eos_token  # pad with the EOS token
# prepend the prompt ("input") to the chosen and rejected completions
def prefix_with_input(example):
    example['chosen'] = example['input'] + " " + example['chosen']
    example['rejected'] = example['input'] + " " + example['rejected'][0]
    return example
train_dataset = train_dataset.map(prefix_with_input)
train_dataset = train_dataset.remove_columns(["input"])
# explicitly tokenizing the dataset
max_length = 8192
def tokenize_function(examples):
    return tokenizer(examples["chosen"], max_length=max_length, padding='max_length', truncation=True)
train_dataset = train_dataset.map(tokenize_function, batched=True)
training_args = RewardConfig(
    dataloader_pin_memory=False,
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    gradient_accumulation_steps=4,
)
# note: optimize_cuda_cache is not a RewardConfig field (it was a PPOConfig option),
# so this attribute assignment most likely has no effect
training_args.optimize_cuda_cache = True
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
)
trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    peft_config=peft_config,
)
trainer.train()
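
For reference, this is roughly what I had in mind for option (2): a minimal sketch assuming bitsandbytes 4-bit (QLoRA-style) quantization, with illustrative settings I have not actually tested at longer context lengths:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# illustrative 4-bit settings, not validated for my setup
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "gemma3",
    attn_implementation="eager",
    quantization_config=bnb_config,
)

My understanding is that this mainly shrinks the stored weights, so I am not sure how much it would help with activation memory at 16K+ sequence lengths.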