Here is a minimal example.

For the evaluation function,

```
from typing import Any, Dict, List, Union

import torch
from torch.nn.functional import cross_entropy
from transformers import EvalPrediction


def perplexity_from_logits(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,  # cross_entropy requires integer (long) targets
    shift: bool = True,
    normalize: bool = True,
) -> Union[float, List[float]]:
    if shift:
        labels = labels[..., 1:]
        logits = logits[..., :-1, :]
    with torch.no_grad():
        if normalize:
            # Corpus-level perplexity: exp of the mean token loss.
            perplexity = torch.exp(
                cross_entropy(logits.permute(0, 2, 1), labels)
            ).item()
        else:
            # Per-sample perplexity: exp of each sample's mean token loss
            # (the mean must be taken before exponentiating).
            perplexity = torch.exp(
                cross_entropy(logits.permute(0, 2, 1), labels, reduction='none')
                .mean(dim=-1)
            ).tolist()
    return perplexity


def compute_metric(eval_preds: EvalPrediction) -> Dict[str, Any]:
    (logits, hidden_states), labels = eval_preds
    # The Trainer passes numpy arrays, so convert back to tensors first.
    return {'perplexity': perplexity_from_logits(
        logits=torch.as_tensor(logits), labels=torch.as_tensor(labels)
    )}
```
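As a quick sanity check on the definition used above (a minimal dependency-free sketch, not tied to the actual model or data): perplexity is the exponential of the mean per-token negative log-likelihood, so a model that assigns every target token probability 1/V has perplexity exactly V.

```python
import math

def perplexity(token_probs):
    """Perplexity = exp(mean negative log-likelihood) over target tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# A uniform model over a 100-token vocabulary is maximally "perplexed":
print(perplexity([1 / 100] * 5))  # ≈ 100.0

# Higher probabilities on the correct tokens mean lower perplexity:
print(perplexity([0.9, 0.8, 0.95]))
```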

For the training parameters,

```
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
)

model = T5ForConditionalGeneration.from_pretrained('t5-small')

training_args = Seq2SeqTrainingArguments(
    output_dir='tmp',
    evaluation_strategy='steps',
    eval_steps=1,
    eval_accumulation_steps=1,
    save_total_limit=None,
    max_steps=10,
    seed=42,
    prediction_loss_only=False,
    do_eval=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    data_collator=collator,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    compute_metrics=compute_metric,
)
trainer.train()
```
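For context on how much work these settings trigger (back-of-the-envelope arithmetic only; the eval batch size of 8 is the `per_device_eval_batch_size` default, assumed here, not set in the snippet): `eval_steps=1` runs a full pass over `dataset_val` after every training step, so `max_steps=10` means ten complete evaluation passes.

```python
# Assumed values: eval batch size 8 is the Trainer default;
# the dataset size is from the post below.
eval_samples = 249_459
eval_batch_size = 8          # per_device_eval_batch_size default (assumed)
max_steps = 10
eval_steps = 1               # evaluate after every training step

batches_per_eval = -(-eval_samples // eval_batch_size)  # ceil division
num_evals = max_steps // eval_steps
total_eval_batches = batches_per_eval * num_evals

print(batches_per_eval)    # forward passes per evaluation
print(total_eval_batches)  # forward passes over the whole run
```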

In my use case, `dataset_val` has 249,459 samples. The dataset is too large and complicated to share here.

Since the evaluation slows down as it progresses, the cause might be either the large evaluation set or a memory leak (unlikely) coming into play. Any insight into working with large evaluation sets would be helpful.
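One way to see why the large evaluation set is a plausible culprit (rough arithmetic only; the average target length of 64 tokens and float32 logits are assumptions, while 32,128 is t5-small's standard vocabulary size): because `compute_metrics` needs the raw logits, the Trainer has to accumulate a logits tensor for every evaluation sample, and that buffer grows with the dataset.

```python
# Assumptions (not from the post): average target length of 64 tokens,
# float32 logits. 32,128 is t5-small's vocabulary size.
num_samples = 249_459
seq_len = 64          # assumed average target sequence length
vocab_size = 32_128   # t5-small vocabulary
bytes_per_float = 4   # float32

logits_bytes = num_samples * seq_len * vocab_size * bytes_per_float
print(f"{logits_bytes / 1e9:.0f} GB of logits accumulated per eval pass")
```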