Single batch training on multi-gpu

Okay, I think I found a good way to solve this.

My confusion came from not knowing what exactly Accelerate was doing under the hood.

So (in my case) Accelerate uses PyTorch's DistributedDataParallel (DDP) to replicate the model across the GPUs. The batches from the training dataloader are split among the devices, so I have to choose my per-device batch size as batch_size = total_batch_size / num_gpus (see here). The difference in observed validation loss comes from each device evaluating a different subset of the dataloader.
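
For illustration, here is a minimal sketch of that setup. The dataset, model, and optimizer are placeholders for your own objects, and the total batch size of 64 is just an assumed example:

from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()

# assumed example: an effective batch size of 64 split evenly across the GPUs
total_batch_size = 64
per_device_batch_size = total_batch_size // accelerator.num_processes

# train_dataset, model and optimizer are placeholders for your own objects
train_dataloader = DataLoader(train_dataset, batch_size=per_device_batch_size, shuffle=True)

# prepare() wraps the model in DDP and shards the dataloader across processes,
# so each device sees a different subset of the batches
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)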

Then, to assemble the validation loss for early_stopping one has to collect the losses from the devices as follows:

# gather the per-device eval losses on every process and average them
global_eval_loss = torch.mean(accelerator.gather_for_metrics(eval_loss)).item()
if es_callback.check_early_stopping(global_eval_loss):
    logger.info(f"Stopping early after epoch {epoch}")
    # tell all processes that training should stop
    accelerator.set_trigger()
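
For completeness, the trigger set above still has to be checked by every process inside the training loop so that all ranks stop together. A rough sketch, where train_one_epoch and evaluate are placeholders for your own functions:

for epoch in range(num_epochs):
    train_one_epoch(...)       # placeholder for the actual training step
    eval_loss = evaluate(...)  # placeholder returning this device's validation loss tensor

    # gather, average and possibly call accelerator.set_trigger() as shown above

    # check_trigger() returns True on every process once any process has set the trigger
    if accelerator.check_trigger():
        break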

Hope this helps anyone who runs into the same problem.
