Single batch training on multi-gpu

Hello,

I am again adapting the run_glue_no_trainer.py script and using my own version of early stopping. I am training a model on a single batch of data, i.e. 32 samples with a per_device_batch_size of 32. Also, I am training on 2 GPUs.

What caught my eye is that when printing the validation loss, it gets printed twice:

print(f'Eval loss gathered:')
print(accelerator.gather_for_metrics(eval_loss))

results in

Eval loss gathered:
tensor([126.0983, 125.1440], device='cuda:1')
Eval loss gathered:
tensor([126.0983, 125.1440], device='cuda:0')

Getting two print statements makes sense, as I am using 2 GPUs. What doesn't make sense is the difference in eval_loss. Since no dropout is being applied, I assume the difference must come from differently updated weights.

What I am ultimately interested in is what I need to consider when doing single-batch training on multiple GPUs.

  1. For example, in the above situation, are the model weights updated twice in a single step?
  2. How do I handle the different losses, e.g. when using early stopping?
  3. I am not 100% sure which paths through the code the different processes take and at which points synchronization happens. Is there any documentation on that?

Some clarification would be greatly appreciated.
Cheers!

EDIT: I should extend my questions, as something else surprised me:

I implement my early stopping as follows:

import torch

class early_stopping_callback:
  def __init__(self, min_delta=0, patience=5):
    self.min_delta = min_delta
    self.patience = patience
    self.counter = 0
    self.lowest_loss = float('inf')

  def check_early_stopping(self, eval_loss):
    # Improvement of the current loss over the best loss seen so far
    delta = self.lowest_loss - eval_loss
    print(f"DEVICE {torch.cuda.current_device()}: current eval_loss {eval_loss}, lowest eval_loss {self.lowest_loss}, delta {delta}")
    if delta >= self.min_delta:
      self.lowest_loss = eval_loss
      self.counter = 0
    else:
      self.counter += 1
      if self.counter >= self.patience:
        return True
    return False
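
For reference, I create a single instance of the callback before the training loop; the values below are just the defaults and are only for illustration:

es_callback = early_stopping_callback(min_delta=0, patience=5)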

I make calls to the function as follows:

eval_loss = 0
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
        eval_loss += outputs.loss.detach().float()
        ...
        ...

if es_callback.check_early_stopping(eval_loss.item()):
    print(f"Stopping early after epoch {epoch}")
    accelerator.set_trigger()

What is surprising to me is that there seems to be one early stopping object per process:

Step 1: ...
DEVICE 0: current eval_loss 126.09870910644531, lowest eval_loss inf, delta inf 
DEVICE 1: current eval_loss 125.14447021484375, lowest eval_loss inf, delta inf 

...

DEVICE 1: current eval_loss 125.14430236816406, lowest eval_loss 125.14447021484375, delta 0.0001678466796875
DEVICE 0: current eval_loss 126.0985107421875, lowest eval_loss 126.09870910644531, delta 0.0001983642578125

...

This probably makes sense in some way, but I am not familiar with how Accelerate works under the hood, so I would appreciate it if somebody could shed some light on where the problem and the answer lie. Thanks!

Okay, I think I found a good way to solve this.

The origin of my confusion was that I did not know exactly what Accelerate was doing under the hood.

So (in my case) Accelerate is using PyTorch's DistributedDataParallel (DDP) to clone the model across the GPUs. The batches in the training dataloader are distributed among the devices. Therefore, I have to choose my per-device batch size as per_device_batch_size = total_batch_size / num_gpus. The difference in the observed validation loss comes from the fact that each device saw a different subset of the dataloader.
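
To make that concrete, here is a minimal sketch of how the dataloader can be set up; the dataset is a dummy TensorDataset and the variable names are only for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# 32 samples that should form one effective batch per optimizer step
total_batch_size = 32
per_device_batch_size = total_batch_size // accelerator.num_processes  # 16 per device on 2 GPUs

# Dummy data standing in for the real dataset
dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
train_dataloader = DataLoader(dataset, batch_size=per_device_batch_size, shuffle=True)

# After prepare(), each process draws its own distinct batches, so one step
# consumes per_device_batch_size * num_processes samples in total
train_dataloader = accelerator.prepare(train_dataloader)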

Then, to assemble the validation loss for early stopping, one has to collect the losses from all devices as follows:

global_eval_loss = torch.mean(accelerator.gather_for_metrics(eval_loss)).item()
if es_callback.check_early_stopping(global_eval_loss):
    logger.info(f"Stopping early after epoch {epoch}")
    accelerator.set_trigger()
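
For completeness, this is roughly how the trigger ties into the training loop; the sketch assumes that model, optimizer and train_dataloader have already been passed through accelerator.prepare():

for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        accelerator.backward(outputs.loss)
        optimizer.step()
        optimizer.zero_grad()

    # ... evaluation loop and es_callback.check_early_stopping(...) as above ...

    # check_trigger() synchronizes the flag across processes,
    # so every rank leaves the loop in the same epoch
    if accelerator.check_trigger():
        break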

Hope this helps anyone who runs into the same problem.
