Supervised Fine-tuning Trainer - Loss function calculation

I am attempting to create a custom loss function by subclassing SFTTrainer. As an initial test that it works, I'm using a conventional cross-entropy loss as the custom function.

from torch import nn
from trl import SFTTrainer

class CustomSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # Plain token-level cross-entropy on the raw logits vs. the labels.
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Flatten to (batch * seq_len, vocab_size) against (batch * seq_len,).
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        return (loss, outputs) if return_outputs else loss
    
trainer = CustomSFTTrainer(
    model=model,
    train_dataset=dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)
trainer.train()
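
To narrow down where the numbers diverge, one check I'm planning to run is comparing my compute_loss against the loss the model itself returns on a single batch. A rough sketch (the batch handling here is my own assumption, just pulling one prepared batch from the trainer's dataloader):

import torch

# Hypothetical debugging step: grab one collated batch from the trainer.
batch = next(iter(trainer.get_train_dataloader()))
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
    outputs = model(**batch)

# Loss computed internally by the model when labels are passed in.
print("model-internal loss:", outputs.loss.item())
# Loss from my custom compute_loss above.
print("custom loss:", trainer.compute_loss(model, batch).item())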
        

This gives different losses and results from running the same training with the out-of-the-box SFTTrainer, as below:

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)

trainer.train()

How is the loss function calculated in the regular SFTTrainer that would be causing these differences?
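
For reference, my rough (unverified) guess is that the stock trainer simply uses the loss the model computes itself when labels are present in the inputs, rather than re-deriving it from the raw logits, i.e. something along these lines:

# My assumption of what the default compute_loss effectively reduces to when
# labels are present and no label smoothing is configured -- not the actual
# library code.
def default_compute_loss(model, inputs, return_outputs=False):
    outputs = model(**inputs)   # the model computes its own loss from the labels
    loss = outputs.loss         # for causal LMs this is CE on shifted logits/labels
    return (loss, outputs) if return_outputs else loss

If that is right, the internal logit/label shift alone could explain the mismatch, but I'd like to confirm.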
