Compute Perplexity using compute_metrics in SFTTrainer

How can I compute perplexity as a metric when using the SFTTrainer and log it at the end of each epoch via the compute_metrics argument? I intend to pick the checkpoint with the lowest perplexity.

Here are the shapes of the logits and labels that go into the compute_metrics function: logits are (50, 256, 50272), i.e. (total_records, seq_len, vocab_size), and labels are (50, 256).

How can I compute perplexity using a compute_metrics function for a CausalLM task?

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb")
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset["train"].select(range(50)),
    eval_dataset=dataset["test"].select(range(50)),
    dataset_text_field="text",
    max_seq_length=256,
    compute_metrics=...,  # logic goes here
    args=training_args,
)
trainer.train()
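
For reference, here is a minimal sketch of what that compute_metrics logic could look like, given the shapes above. It assumes the eval labels follow the usual transformers convention of marking ignored/padded positions with -100, and compute_perplexity is just an illustrative name, not an official API:

import torch
import torch.nn.functional as F

def compute_perplexity(eval_pred):
    logits = torch.from_numpy(eval_pred.predictions)  # (total_records, seq_len, vocab_size)
    labels = torch.from_numpy(eval_pred.label_ids)    # (total_records, seq_len)
    # Shift so that tokens < n predict token n (standard causal LM alignment).
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Mean cross-entropy over non-ignored tokens; -100 marks padding (assumption).
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
    # Perplexity is the exponential of the mean token-level loss.
    return {"perplexity": torch.exp(loss).item()}

With something like this passed as compute_metrics, checkpoint selection can then be driven from TrainingArguments, e.g. load_best_model_at_end=True, metric_for_best_model="perplexity", greater_is_better=False (the Trainer prepends the eval_ prefix automatically), with evaluation running each epoch. One caveat: compute_metrics receives the full logits for the whole eval set, so for larger eval sets you may want preprocess_logits_for_metrics to reduce memory.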
