Early stopping implementation in Accelerate?

Is it possible to implement early stopping while using Accelerate? I know Accelerate handles distributed training for normal PyTorch training loops, but I’m not sure how to handle early stopping, since one process could meet the early-stopping criterion while another does not. I was thinking of something like this:

for epoch in range(num_epochs):
    for batch in train_dataloader:
        outputs = my_model(**batch)
        loss = outputs['loss']
    metric = my_eval(my_model, dev_dataloader)  # evaluation on the dev set (i.e., held out from training)
    if my_early_stop.step(metric):
        break  # early stop criterion is met, we can stop now

Why would the processes see different metrics? Normally they will all compute the same one.

@sgugger Thanks for your response. Maybe my understanding of Accelerate is incorrect, but I thought each process saw a different slice of the training and dev sets. Each process would then compute the same metric, but on a different slice of the data.

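In that case, you should gather the tensors from all processes (for example with accelerator.gather) before feeding them to your metric function, as is done in all the examples. Every process then computes the metric on the full dev set and reaches the same early-stopping decision.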
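For illustration, here is a minimal sketch of what an object like my_early_stop from the question could look like. The class name EarlyStopper and its patience, min_delta, and mode parameters are hypothetical and not part of Accelerate. The sketch assumes the metric passed to step() was computed from gathered tensors, so each process feeds in the same value and the purely deterministic logic breaks out of the loop at the same epoch everywhere.

```python
class EarlyStopper:
    """Hypothetical helper (not an Accelerate API): signals a stop after
    `patience` consecutive evaluations without improvement."""

    def __init__(self, patience=3, min_delta=0.0, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience    # evaluations to wait without improvement
        self.min_delta = min_delta  # minimum change that counts as improvement
        self.mode = mode            # "max" for accuracy-like, "min" for loss-like metrics
        self.best = None
        self.bad_evals = 0

    def step(self, metric):
        """Record one evaluation result; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = metric > self.best + self.min_delta
        else:
            improved = metric < self.best - self.min_delta

        if improved:
            self.best = metric
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


# Simulated per-epoch dev accuracies. In real training, this value would come
# from an evaluation over predictions gathered from all processes.
stopper = EarlyStopper(patience=2, mode="max")
history = [0.70, 0.74, 0.73, 0.74, 0.72]
stopped_at = None
for epoch, metric in enumerate(history):
    if stopper.step(metric):
        stopped_at = epoch  # early-stopping criterion met
        break
```

Because the helper holds no process-specific state, running it on identical metric values yields identical decisions on every process, so no extra synchronization of the stop flag should be needed under that assumption.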

Ah yes, thank you! Apologies for the simple question. I’m still learning Accelerate. Thank you for your help!