Compute Perplexity using compute_metrics in SFTTrainer

How can I compute perplexity as a metric when using the SFTTrainer, and log it at the end of each epoch via the compute_metrics argument? I intend to pick the best checkpoint, i.e. the one with the lowest perplexity.

Here are the shapes of the logits and labels that go into the compute_metrics function: the logits are (50, 256, 50272), i.e. (total_records, seq_len, vocab_size), and the labels are (50, 256).

How can I compute perplexity using a compute_metrics function for a CausalLM task?

from datasets import load_dataset

dataset = load_dataset("imdb")
trainer = SFTTrainer(
    eval_dataset = dataset["test"].select(range(50)),
    compute_metrics = ...,  # logic goes here
    args = training_args,
)
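As a starting point, here is a minimal sketch of such a compute_metrics function. It assumes the predictions arrive as a NumPy array of logits with the shapes described above, with padded positions labelled -100 (the usual ignore index), and computes perplexity as exp of the mean token-level cross-entropy after the standard causal-LM shift:

```python
import numpy as np

def compute_metrics(eval_pred):
    """Perplexity = exp(mean cross-entropy over non-ignored tokens)."""
    logits, labels = eval_pred  # logits: (batch, seq_len, vocab), labels: (batch, seq_len)
    # Shift so that tokens < n predict token n (standard causal-LM alignment)
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - (m + np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True)))
    # Positions labelled -100 are padding/ignored and must not count toward the loss
    mask = shift_labels != -100
    safe_labels = np.clip(shift_labels, 0, None)  # avoid indexing with -100
    token_logp = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1).squeeze(-1)
    mean_nll = -(token_logp * mask).sum() / mask.sum()
    return {"perplexity": float(np.exp(mean_nll))}
```

To make checkpoint selection use this metric, you can set `metric_for_best_model="eval_perplexity"`, `greater_is_better=False`, and `load_best_model_at_end=True` in TrainingArguments (the Trainer prefixes logged eval metrics with `eval_`). Note that materializing (50, 256, 50272) float logits is memory-heavy; the Trainer's `preprocess_logits_for_metrics` hook can be used to reduce them before they are gathered.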
