compute_metrics() behaves strangely in a distributed setting

Hello!

I am using the HF Seq2SeqTrainer to fine-tune CodeT5 on a task. I launch training with torchrun on a single node with 2 GPUs.

I recently noticed that the compute_metrics() function I wrote receives an EvalPrediction instance with more sequences than my validation set contains. The validation set I pass to the Trainer has 250 examples, but if you add

print(preds.predictions.shape)

into your compute_metrics() function, it prints (264, *), where * varies with the longest sequence the model generates.

Upon further inspection, the excess entries [250, 264) appear to just be repeats of entries [0, 14). The issue disappears completely if I run without parallelisation.
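For what it's worth, the numbers would line up if the eval sampler rounds the dataset up so every rank gets the same number of full batches, wrapping the extra slots around to the start. A minimal sketch of that round-up arithmetic, assuming a hypothetical per-device eval batch size of 12 (I haven't confirmed this is what the Trainer's sampler does in my version):

```python
import math

def padded_eval_size(num_samples, per_device_batch, world_size):
    # Round the sample count up to a multiple of (batch * world_size)
    # so each rank receives the same number of complete batches.
    chunk = per_device_batch * world_size
    return math.ceil(num_samples / chunk) * chunk

# Hypothetical numbers: 250 val samples, 2 GPUs, per-device eval batch of 12.
total = padded_eval_size(250, 12, 2)
print(total)        # 264
print(total - 250)  # 14 extra slots, which would wrap to indices [0, 14)
```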

Also, compute_metrics() seems to be called on both devices, since the shape information is printed twice (once from rank 0 and once from rank 1).

Can anyone suggest what is happening? :frowning: