Is it possible to implement early stopping while using Accelerate? I know Accelerate handles distributed training for ordinary PyTorch training loops, but I'm not quite sure how to handle early stopping, since one process could meet the early-stopping criterion while another may not. I was thinking of something like this:
for epoch in range(num_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = my_model(**batch)
        loss = outputs['loss']
        my_accelerator.backward(loss)
        optimizer.step()
    metric = my_eval(my_model, dev_dataloader)  # evaluation on the dev set (i.e., holdout from training)
    if my_early_stop.step(metric):
        break  # early stopping criterion is met, we can stop now
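To keep the stopping decision consistent across processes, I was also wondering whether I could gather the metric with accelerator.gather, so that every process averages the same values and stops on the same epoch. Here is a rough sketch of what I mean (assuming my_eval returns a plain Python float computed on each process's shard of the dev set; my_eval and my_early_stop are just my own helpers):

import torch

for epoch in range(num_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = my_model(**batch)
        loss = outputs['loss']
        my_accelerator.backward(loss)
        optimizer.step()

    # each process evaluates its own shard of the dev set
    local_metric = my_eval(my_model, dev_dataloader)

    # gather the per-process metrics and average them, so every process
    # sees the same number and reaches the same stop/continue decision
    metric = my_accelerator.gather(
        torch.tensor([local_metric], device=my_accelerator.device)
    ).mean().item()

    if my_early_stop.step(metric):
        break  # all processes break on the same epoch, so nothing hangs

Would this be a reasonable way to keep the processes in sync, or does Accelerate already provide something for this?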