Hello,
I am again adapting the run_glue_no_trainer.py script and using my own version of early stopping. I am training a model on a single batch of data, i.e. 32 samples with a per_device_batch_size of 32, and I am training on 2 GPUs.
What caught my eye is that the validation loss gets printed twice:
print(f'Eval loss gathered:')
print(accelerator.gather_for_metrics(eval_loss))
results in
Eval loss gathered:
tensor([126.0983, 125.1440], device='cuda:1')
Eval loss gathered:
tensor([126.0983, 125.1440], device='cuda:0')
Getting two print statements makes sense, as I am using 2 GPUs. What doesn't make sense to me is the difference in eval_loss: since no dropout is being applied, I expect the difference to come from the updated weights.
What I am ultimately interested in is what I need to consider when doing single-batch training on multiple GPUs.
- For example, in the above situation, are the model weights updated twice in a single step?
- How do I handle the different losses, e.g. when using early stopping? (See the sketch after this list for what I have in mind.)
- I am not 100% sure which code paths the individual processes take and at which points synchronization happens. Is there any documentation on that?
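To make the second bullet more concrete, here is what I currently have in mind (just a guess on my part, not something I found in the docs): average the gathered per-process losses into a single number and use that for the early-stopping decision.
# Sketch only -- my guess at how to get one loss value that is identical on every process.
gathered = accelerator.gather_for_metrics(eval_loss)   # e.g. tensor([126.0983, 125.1440])
mean_eval_loss = gathered.mean().item()                # same float on every process
Is that the intended approach, or does Accelerate provide something built in for this?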
Some clarification would be greatly appreciated.
Cheers!
EDIT: I should extend my questions, as something else has surprised me.
I implement my early stopping as follows:
class early_stopping_callback:
    def __init__(self, min_delta=0, patience=5):
        self.min_delta = min_delta
        self.patience = patience
        self.counter = 0
        self.lowest_loss = float('inf')

    def check_early_stopping(self, eval_loss):
        delta = self.lowest_loss - eval_loss
        print(f"DEVICE {torch.cuda.current_device()}: current eval_loss {eval_loss}, lowest eval_loss {self.lowest_loss}, delta {delta} ")
        if delta >= self.min_delta:
            self.lowest_loss = eval_loss
            # self.lowest_loss_index = -1
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
I make calls to the function as follows:
eval_loss = 0
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
    eval_loss += outputs.loss.detach().float()
...
...
if es_callback.check_early_stopping(eval_loss.item()):
    print(f"Stopping early after epoch {epoch}")
    accelerator.set_trigger()
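My understanding is that accelerator.set_trigger() only has an effect if every process later calls accelerator.check_trigger(), so I would break out of the training loop roughly like this (a simplified sketch, not my exact loop; variable names roughly follow run_glue_no_trainer.py):
# Sketch of how I understand the set_trigger/check_trigger pattern (simplified).
for epoch in range(num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        accelerator.backward(outputs.loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    # ... evaluation loop from above runs here and may call accelerator.set_trigger() ...
    if accelerator.check_trigger():   # True on all processes if any process set the trigger
        break
Is that the intended usage, or does the trigger have to be checked somewhere else?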
What is surprising to me is that there seems to be one early-stopping object per process:
Step 1: ...
DEVICE 0: current eval_loss 126.09870910644531, lowest eval_loss inf, delta inf
DEVICE 1: current eval_loss 125.14447021484375, lowest eval_loss inf, delta inf
...
DEVICE 1: current eval_loss 125.14430236816406, lowest eval_loss 125.14447021484375, delta 0.0001678466796875
DEVICE 0: current eval_loss 126.0985107421875, lowest eval_loss 126.09870910644531, delta 0.0001983642578125
...
This probably makes sense in some way, but I am not proficient in how Accelerate works under the hood, so could somebody please shed some light on where the problem and the answer lie?
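If the answer is simply that each process keeps its own Python objects, so each rank's early_stopping_callback only ever sees its own loss, would the right fix be to average the loss across processes before the check, so that every callback sees the same number? An untested sketch of what I mean, assuming accelerator.reduce with reduction="mean" averages the tensor over all processes:
# Untested sketch: give every rank's callback the same averaged loss.
eval_loss = accelerator.reduce(eval_loss, reduction="mean")   # same scalar on every process
if es_callback.check_early_stopping(eval_loss.item()):
    print(f"Stopping early after epoch {epoch}")
    accelerator.set_trigger()
Thanks!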