Wav2Vec2 Loss Function Question

So I am working with Wav2Vec2 so I can apply it to a new dataset. To do this, I have mainly been looking at the example trainer provided:

My main question is this:

When the model computes the loss, it adds together the losses across all the samples. This shows up both in the cross-entropy used for the contrastive loss (reduction="sum") and in the diversity loss, which sums across the codevectors and the codebooks.
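To make the scale concrete, here is a rough, self-contained sketch of the kind of computation I mean. The shapes, the 101 candidates (1 positive + 100 sampled negatives), and the all-zeros target are just placeholders for illustration, not the exact HF implementation:

import torch
import torch.nn.functional as F

# placeholder shapes: ~1000 masked positions, each scored against 1 positive + 100 negatives
num_masked, num_candidates = 1000, 101
logits = torch.randn(num_masked, num_candidates)
target = torch.zeros(num_masked, dtype=torch.long)  # index 0 = the true quantized vector

# reduction="sum": the contrastive loss grows with the number of masked positions
contrastive_loss = F.cross_entropy(logits, target, reduction="sum")
print(contrastive_loss.item())  # roughly num_masked * log(101), i.e. in the thousands

With random logits each masked position contributes about log(101) ≈ 4.6, which (if I am reading the default config with 100 negatives correctly) is also roughly where the per-token average starts out.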

So the actual loss that we call loss.backward() on is pretty large (in the thousands)!

On the other hand, when the loss is printed out to the console it starts out around 4. Upon further inspection, I see that they divide the loss by the number of masked tokens before logging (so it is basically an average at that point). So we are printing a loss with a value around 4, but actually training on a loss that is in the thousands. Is this correct? Is it OK to train on such a large loss value, and why aren't we just taking the average before calling loss.backward()?
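Just to illustrate the gap between the two numbers (toy values, not the real script):

import torch

num_losses = 1000  # e.g. ~1000 masked time steps in a batch
per_token_losses = torch.full((num_losses,), 4.0, requires_grad=True)

loss = per_token_losses.sum()        # what backward() actually sees: 4000
logged = (loss / num_losses).item()  # what gets printed to the console: 4.0
loss.backward()
print(loss.item(), logged)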

This might be a dumb question but thanks!!

So now that I am looking more carefully, there is also this bit of code:

if accelerator.state.num_processes > 1:
    num_losses = accelerator.gather_for_metrics(num_losses).sum()
    gradient_multiplier = accelerator.state.num_processes / num_losses
    multiply_grads(model.module.parameters(), gradient_multiplier)
else:
    multiply_grads(model.parameters(), 1 / num_losses)

Now it makes sense to me. We don't know the number of losses ahead of time across all GPUs because we mask randomly, so scaling the gradients by 1 / num_losses after the backward pass effectively does the same thing as averaging before backward.
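If I understand it right, the extra num_processes factor in the multi-GPU branch is there because DDP averages gradients across processes, so multiplying by num_processes / num_losses recovers the sum of gradients divided by the total number of masked tokens. And scaling the gradients by 1 / num_losses after calling backward() on the summed loss should be mathematically the same as taking the mean before backward(), since gradients are linear. A quick toy check of that equivalence (not the actual training script):

import torch

x = torch.randn(5, requires_grad=True)

# Option A: backward() on the summed loss, then scale the gradients afterwards
(x ** 2).sum().backward()
grads_scaled = x.grad.clone() / x.numel()

# Option B: take the mean before backward()
x.grad = None
(x ** 2).mean().backward()
grads_mean = x.grad.clone()

print(torch.allclose(grads_scaled, grads_mean))  # True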