Hi there,
I tried writing a custom-loss trainer for Learning without Forgetting (LwF). However, it takes hours longer than the standard SFTTrainer would, and the losses it prints are unreasonable. I log the loss every 10 steps, and 10 steps of training took about 5 minutes to run; the training loss was 2332.57 while the validation loss was 5.92, so the two values are not even comparable (image attached). What am I doing wrong?
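For reference, the standard SFTTrainer run I am comparing against looks roughly like this (same model, tokenizer, dataset, and training arguments as in the code below):

```python
# Baseline for comparison: plain SFTTrainer with its default loss, no teacher model
from trl import SFTTrainer

baseline_trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset_dict["train"],
    eval_dataset=formatted_dataset_dict["test"],
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    packing=False,
)
baseline_trainer.train()
```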
Here are my collator and trainer classes:

```python
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import DataCollatorForLanguageModeling
from trl import SFTTrainer


class LwFDataCollator:
    def __init__(self, tokenizer, original_model, alpha=0.5):
        self.alpha = alpha
        # Initialize the LM collator (for padding AND labels)
        # mlm=False indicates causal language modeling (predict next token)
        self.lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    def __call__(self, examples):
        # Use LM collator to handle padding AND create the 'labels' field
        # needed for loss calculation during both training and evaluation.
        batch = self.lm_collator(examples)
        # Add alpha for knowledge distillation loss calculation in LwFTrainer
        # Note: alpha is popped in compute_loss, so it won't interfere with standard eval
        batch["alpha"] = self.alpha
        return batch


class LwFTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep a CPU copy of the initial model (teacher)
        self.teacher_model = deepcopy(self.model).cpu().eval()

    # Add **kwargs to accept unexpected arguments from the training loop
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        alpha = inputs.pop("alpha", 0.5)
        device = next(model.parameters()).device

        # Copy inputs for teacher
        inputs_for_teacher = {
            k: v.clone() if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Student forward pass
        outputs = model(**inputs)
        student_logits = outputs.logits

        # Task (language modeling) loss
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = inputs["input_ids"][..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        task_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )

        # Teacher forward pass: move teacher to GPU, run inference, move back to CPU
        with torch.no_grad():
            # Move teacher model to the same device as the student model
            self.teacher_model.to(device)
            # Pass original inputs (already on the correct device)
            teacher_outputs = self.teacher_model(**inputs_for_teacher)
            teacher_logits = teacher_outputs.logits  # Already on the correct device
            # Move teacher model back to CPU to save GPU memory
            self.teacher_model.cpu()

        # Knowledge-distillation loss
        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
        temperature = 2.0
        kd_loss = F.kl_div(
            F.log_softmax(shift_logits / temperature, dim=-1),
            F.softmax(shift_teacher_logits / temperature, dim=-1),
            reduction="batchmean"
        ) * (temperature ** 2)

        # Combine task + KD
        loss = alpha * task_loss + (1 - alpha) * kd_loss
        return (loss, outputs) if return_outputs else loss
```
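As a quick check on the collator output, it produces batches like this (a minimal sketch with two dummy examples; it assumes the tokenizer below is already loaded with its pad token set):

```python
# Quick shape check of what LwFDataCollator returns
dummy_features = [
    {"input_ids": tokenizer("Hello world")["input_ids"]},
    {"input_ids": tokenizer("A slightly longer example sentence")["input_ids"]},
]
check_collator = LwFDataCollator(tokenizer, model, alpha=0.7)
check_batch = check_collator(dummy_features)
print({k: (tuple(v.shape) if isinstance(v, torch.Tensor) else v) for k, v in check_batch.items()})
# Expected keys: input_ids, attention_mask, labels, alpha
```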
```python
import transformers
from unsloth import is_bfloat16_supported

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```
Step 6: Configure the training arguments and start training
```python
training_args = transformers.TrainingArguments(
    output_dir="./lwf_finetuned_model",
    dataloader_pin_memory=False,
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=100,
    logging_steps=10,
    label_names=[],
    # --- Start evaluation arguments ---
    evaluation_strategy="steps",    # Evaluate every eval_steps
    eval_steps=10,                  # Evaluation frequency (match save_steps)
    save_strategy="steps",          # Save checkpoints at the same frequency as evaluation
    save_steps=100,                 # Save frequency (should generally match eval_steps)
    load_best_model_at_end=True,    # Load the best model based on eval loss at the end
    # metric_for_best_model="loss", # Optional: defaults to loss
    save_total_limit=2,             # Optional: limit the total number of checkpoints saved
    # --- End evaluation arguments ---
    fp16=False,
    bf16=True,
    report_to="none",
)
```
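For context on the step time, with these settings a single optimizer step processes 128 sequences (assuming a single GPU):

```python
# Effective batch size per optimizer step (assuming a single GPU)
effective_batch = (
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
)
print(effective_batch)  # 8 * 16 = 128 sequences per optimizer step
```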
```python
# Create data collator with LwF
lwf_data_collator = LwFDataCollator(tokenizer, model, alpha=0.7)

# Initialize trainer with custom components
trainer = LwFTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset_dict["train"],
    eval_dataset=formatted_dataset_dict["test"],
    tokenizer=tokenizer,
    data_collator=lwf_data_collator,
    max_seq_length=max_seq_length,
    # dataset_text_field="text",
    packing=False,
)
trainer.can_return_loss = True
```
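Training is then started with `trainer.train()`. To make the numbers easier to reproduce, here is a minimal single-batch check of the custom loss (a sketch: it assumes the objects above are already built and that each example in `formatted_dataset_dict["train"]` is already tokenized into `input_ids`):

```python
# Minimal single-batch sketch to reproduce the loss value outside the training loop.
# Assumes each train example already contains tokenized input_ids.
features = [formatted_dataset_dict["train"][i] for i in range(2)]
batch = lwf_data_collator(features)

device = next(trainer.model.parameters()).device
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

single_batch_loss = trainer.compute_loss(trainer.model, batch)
print(single_batch_loss.item())
```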