Custom loss trainer takes hours and training loss starts wildly different from validation loss (Learning Without Forgetting)

Hi there,
I tried writing a custom loss trainer for Learning Without Forgetting (LwF). However, it takes hours longer than the standard SFTTrainer would, and the losses it prints are unreasonable. I log the loss every 10 steps, and 10 training steps took about 5 minutes to run. The training loss was 2332.57 while the validation loss was 5.92; these values aren't even on the same scale (image attached). What am I doing wrong?
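In case it helps with the timing question: the trainer below moves the frozen teacher between CPU and GPU on every step, so I timed roughly how long such a round trip takes on my machine with a dummy stack of Linear layers (not my real model, just something of broadly similar size):

```python
import time

import torch
import torch.nn as nn

# Dummy "teacher": ~1.7B parameters of plain Linear layers (fp32),
# only used to get a feel for the cost of a CPU <-> GPU round trip.
dummy_teacher = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(100)])
device = torch.device("cuda")

start = time.time()
dummy_teacher.to(device)
torch.cuda.synchronize()
print(f"cpu -> gpu: {time.time() - start:.2f} s")

start = time.time()
dummy_teacher.cpu()
torch.cuda.synchronize()
print(f"gpu -> cpu: {time.time() - start:.2f} s")
```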

```python

from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import DataCollatorForLanguageModeling
from trl import SFTTrainer


class LwFDataCollator:
    def __init__(self, tokenizer, original_model, alpha=0.5):
        self.alpha = alpha
        # Initialize the LM collator (for padding AND labels)
        # mlm=False indicates causal language modeling (predict next token)
        self.lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    def __call__(self, examples):
        # Use LM collator to handle padding AND create the 'labels' field
        # needed for loss calculation during both training and evaluation.
        batch = self.lm_collator(examples)

        # Add alpha for knowledge distillation loss calculation in LwFTrainer
        # Note: alpha is popped in compute_loss, so it won't interfere with standard eval
        batch["alpha"] = self.alpha
        return batch


class LwFTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep a CPU copy of the initial model (teacher)
        self.teacher_model = deepcopy(self.model).cpu().eval()

    # Add **kwargs to accept unexpected arguments from the training loop
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        alpha = inputs.pop("alpha", 0.5)

        device = next(model.parameters()).device

        # Copy inputs for teacher
        inputs_for_teacher = {
            k: v.clone() if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Student forward pass
        outputs = model(**inputs)
        student_logits = outputs.logits

        # Task (language modeling) loss
        shift_logits = student_logits[..., :-1, :].contiguous()
        shift_labels = inputs["input_ids"][..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        task_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )

        # Teacher forward pass: move teacher to GPU, run inference, move back to CPU
        with torch.no_grad():
            # Move teacher model to the same device as the student model
            self.teacher_model.to(device)
            # Pass original inputs (already on the correct device)
            teacher_outputs = self.teacher_model(**inputs_for_teacher)
            teacher_logits = teacher_outputs.logits  # Already on the correct device
            # Move teacher model back to CPU to save GPU memory
            self.teacher_model.cpu()

        # Knowledge-distillation loss
        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
        temperature = 2.0
        kd_loss = F.kl_div(
            F.log_softmax(shift_logits / temperature, dim=-1),
            F.softmax(shift_teacher_logits / temperature, dim=-1),
            reduction="batchmean"
        ) * (temperature ** 2)

        # Combine task + KD
        loss = alpha * task_loss + (1 - alpha) * kd_loss
        return (loss, outputs) if return_outputs else loss

```
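For what it's worth, here is a standalone check of how `F.kl_div`'s reduction behaves with the 3-D (batch, seq_len, vocab) tensors that `compute_loss` above passes in. The shapes and random logits are made up; they are only there to show the scale difference between `batchmean` and a per-token average:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Dummy shapes roughly matching a causal-LM batch: (batch, seq_len, vocab)
batch, seq_len, vocab = 2, 512, 32000
student = torch.randn(batch, seq_len, vocab)
teacher = torch.randn(batch, seq_len, vocab)
temperature = 2.0

log_p = F.log_softmax(student / temperature, dim=-1)
q = F.softmax(teacher / temperature, dim=-1)

# "batchmean" divides the summed KL only by the first dimension (batch),
# so with 3-D inputs it comes out seq_len times larger than a per-token mean.
kl_batchmean = F.kl_div(log_p, q, reduction="batchmean")
kl_per_token = F.kl_div(log_p, q, reduction="sum") / (batch * seq_len)

print(kl_batchmean.item(), kl_per_token.item())
```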
```python
import transformers
from unsloth import is_bfloat16_supported

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```

Step 6: Configure the training arguments and start training

```python
training_args = transformers.TrainingArguments(
    output_dir="./lwf_finetuned_model",
    dataloader_pin_memory=False,
    num_train_epochs=2,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=100,
    logging_steps=10,
    label_names=[],
    # --- Start Evaluation Arguments ---
    evaluation_strategy="steps",    # Evaluate every eval_steps
    eval_steps=10,                  # Evaluation frequency
    save_strategy="steps",          # Save checkpoints on a step schedule
    save_steps=100,                 # Save frequency (a multiple of eval_steps)
    load_best_model_at_end=True,    # Load the best model based on eval loss at the end
    # metric_for_best_model="loss", # Optional: defaults to loss
    save_total_limit=2,             # Optional: limit the total number of checkpoints saved
    # --- End Evaluation Arguments ---
    fp16=False,
    bf16=True,
    report_to="none",
)
```
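Just to spell out the schedule these arguments imply (single GPU assumed):

```python
# With these arguments (single GPU assumed):
per_device_train_batch_size = 8
gradient_accumulation_steps = 16

effective_batch = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch)       # 128 sequences per optimizer step

# logging_steps=10 and eval_steps=10 mean every log line I see covers
# 10 * 128 = 1280 training sequences, plus a full pass over the eval set.
print(10 * effective_batch)  # 1280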

Create data collator with LwF

```python
lwf_data_collator = LwFDataCollator(tokenizer, model, alpha=0.7)
```
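Before launching the full run, I also did a quick smoke test of the collator on a couple of examples. This assumes the dataset rows are already tokenized to `input_ids` (with any raw-text columns removed); otherwise `DataCollatorForLanguageModeling` will complain:

```python
# Quick smoke test of the LwF collator (assumes tokenized examples)
sample = [formatted_dataset_dict["train"][i] for i in range(2)]
batch = lwf_data_collator(sample)

print(batch.keys())              # expect input_ids, attention_mask, labels, alpha
print(batch["input_ids"].shape)  # (2, padded_seq_len)
print(batch["labels"].shape, batch["alpha"])
```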

Initialize trainer with custom components

```python
trainer = LwFTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset_dict['train'],
    eval_dataset=formatted_dataset_dict['test'],
    tokenizer=tokenizer,
    data_collator=lwf_data_collator,
    max_seq_length=max_seq_length,
    # dataset_text_field="text",
    packing=False,
)
trainer.can_return_loss = True
```
