Compute Perplexity using compute_metrics in SFTTrainer

How can I compute perplexity as a metric when using the SFTTrainer, logging it at the end of each epoch via the compute_metrics argument? I intend to pick the best checkpoint, i.e. the one with the lowest perplexity.

Here are the shapes of the logits and labels that go into the compute_metrics function: the logits are (50, 256, 50272), i.e. (total_records, seq_len, vocab_size), and the labels are (50, 256).

How can I compute perplexity in a compute_metrics function for a causal LM task?

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb")
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset["train"].select(range(50)),
    eval_dataset=dataset["test"].select(range(50)),
    dataset_text_field="text",
    max_seq_length=256,
    compute_metrics=...,  # logic goes here
    args=training_args,
)
trainer.train()
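
For the checkpoint-selection part, the training_args referenced above also need to evaluate and save at each epoch and to track the metric. A minimal sketch (the output directory is a placeholder, and metric_for_best_model must match the key returned by compute_metrics, which the Trainer prefixes with eval_):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="opt-350m-sft",           # placeholder output directory
    evaluation_strategy="epoch",         # run evaluation (and compute_metrics) every epoch
    save_strategy="epoch",               # save a checkpoint at the same cadence
    load_best_model_at_end=True,         # reload the best checkpoint when training ends
    metric_for_best_model="Perplexity",  # logged as "eval_Perplexity"
    greater_is_better=False,             # lower perplexity is better
)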


Hi, I wrote this function to log perplexity at each logging step.

import torch
from torchmetrics.text import Perplexity  # assumes the torchmetrics Perplexity metric
from transformers import EvalPrediction

def compute_metrics(pred: EvalPrediction):
    # pred[0] holds the logits, pred[1] the labels; both arrive as numpy arrays
    predictions = torch.tensor(pred[0])
    targets = torch.tensor(pred[1])
    # ignore_index=-100 skips padded/masked label positions
    perplexity = Perplexity(ignore_index=-100)
    perplexity_score = perplexity(predictions, targets).item()
    return {"Perplexity": perplexity_score}

In the trainer, it is called like this:

trainer = SFTTrainer(
    **trainer_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=config.block_size,
    compute_metrics=utils.compute_metrics,
)
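
If you would rather not depend on an external Perplexity metric, the same quantity can be computed directly as the exponential of the mean cross-entropy over the shifted, non-ignored tokens. A minimal sketch, assuming logits of shape (total_records, seq_len, vocab_size) and labels of shape (total_records, seq_len) padded with -100:

import math

import torch
import torch.nn.functional as F
from transformers import EvalPrediction

def compute_metrics(eval_pred: EvalPrediction):
    logits = torch.tensor(eval_pred.predictions)  # (total_records, seq_len, vocab_size)
    labels = torch.tensor(eval_pred.label_ids)    # (total_records, seq_len)
    # Shift so that the logits at position t are scored against the token at t + 1,
    # matching how the causal-LM training loss is computed.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Mean cross-entropy over real tokens only; -100 marks padding/ignored positions.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
    return {"Perplexity": math.exp(loss.item())}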