Gradient accumulation gives different results compared to full batch

I’m building an image captioner using Hugging Face models with PyTorch, and I’m getting different results on the first iteration (and therefore on every following iteration) even though the effective batch size is the same. A mini-batch of 4 with 16 gradient accumulation steps, a batch size of 64, and any other combination that gives an effective batch size of 64 all produce different losses.

For a batch size of 16 with 4 gradient accumulation steps, the loss of the first iteration is 4.095348954200745. For a batch size of 4 with 16 gradient accumulation steps, it is 4.097771629691124.

Here’s a gist link to the code I’m running.

I’ve also included the code because I had to make some slight modifications for it to accept CLIP models.

I’ve disabled dropout, set the seed and turned off shuffling in the dataloader. As far as I know, the models I’m using have no batch-dependent layers such as batchnorm. The models are gpt2 and laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K (in case you want to look at them). I also don’t think it’s a precision issue: in the code I compare the sum of the losses divided by the number of gradient accumulation steps against each loss divided by the number of steps and then summed, and the two give exactly the same result…

In the same gist I added an attempt at recreating this issue with pure PyTorch and the MNIST dataset, so you can run it more easily. This is the file. The loss difference isn’t as large as the one described above, but in my opinion it’s still a difference that shouldn’t be there, since (I think) there are no precision issues.
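For contrast, here is a minimal sketch of the case where every micro-batch contributes the same number of loss terms, as with image classification. It assumes a hypothetical toy linear classifier on random 28x28 inputs, not the actual MNIST script from the gist, and runs in float64 so that rounding noise stays far below the comparison tolerance. Because the mean over 64 samples equals the average of 16 means over 4 samples each, mean-reduced accumulation (each loss divided by the number of steps) should reproduce the full-batch gradient up to rounding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the MNIST setup: a linear classifier on random
# 28x28 inputs (no dropout, no batchnorm), in float64 to reduce rounding noise.
model = nn.Linear(28 * 28, 10).double()
x = torch.randn(64, 28 * 28, dtype=torch.float64)
y = torch.randint(0, 10, (64,))
loss_fn = nn.CrossEntropyLoss()  # default mean reduction

# Full batch of 64
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# 16 micro-batches of 4, each mean loss divided by the number of accumulation steps
accum_steps = 16
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()
accum_grad = model.weight.grad.clone()

# Largest absolute difference between the two gradients
max_diff = (full_grad - accum_grad).abs().max().item()
```

In float32 the same comparison would typically show small differences in the last digits, because the terms are summed in a different order.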

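I think I figured it out. The “problem” was that I was using mean reduction in my loss while training a model on variable-length sequences. Suppose I have 2 sequences, A and B, where A has 7 tokens and B has 10 tokens, so A needs 3 padding tokens (which the loss ignores). Writing loss_A and loss_B for the summed token losses of each sequence, the mean loss of the 2 sequences in one batch is (loss_A + loss_B)/17, so every token has the same weight. With gradient accumulation over 2 micro-batches of one sequence each, mean reduction gives (loss_A/7 + loss_B/10)/2 after dividing by the number of accumulation steps, so every sequence has the same weight and each token of A counts more than each token of B. It’s easy to see that these 2 are not the same.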
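The arithmetic can be checked with a small pure-Python sketch. The per-token loss values are made up (hypothetical): every token of A has loss 2.0 and every token of B has loss 4.0, and padding tokens are simply left out.

```python
# Hypothetical per-token losses (made-up numbers); padding is excluded.
token_losses_a = [2.0] * 7   # sequence A: 7 real tokens
token_losses_b = [4.0] * 10  # sequence B: 10 real tokens

loss_a = sum(token_losses_a)  # summed token loss of A: 14.0
loss_b = sum(token_losses_b)  # summed token loss of B: 40.0

# Full batch: mean over all 17 non-padding tokens
full_batch = (loss_a + loss_b) / 17  # 54 / 17, about 3.176

# Gradient accumulation over 2 micro-batches of one sequence each:
# mean reduction per micro-batch, then divided by the number of steps
accum_steps = 2
accumulated = (loss_a / 7) / accum_steps + (loss_b / 10) / accum_steps  # (2.0 + 4.0) / 2 = 3.0

# Sum reduction per micro-batch, divided by the token count of the effective batch
total_tokens = 7 + 10
sum_then_divide = loss_a / total_tokens + loss_b / total_tokens  # equals full_batch up to rounding

print(full_batch, accumulated, sum_then_divide)
```

The last variant, which divides each micro-batch's summed loss by the token count of the whole effective batch, recovers the full-batch value; that is the idea raised in the next paragraph.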

To avoid this, would it make sense to change the reduction to sum and divide the loss by the number of tokens in the whole effective batch only at the end? I think this would reflect the “real” loss of the effective batch exactly (up to floating-point error), right?
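A minimal PyTorch sketch of that idea follows. It assumes a hypothetical toy model (an embedding followed by a linear head, with no label shifting) rather than the GPT-2/CLIP captioner, and it marks padding with the label -100, which is the default `ignore_index` of `F.cross_entropy`. It compares the full-batch gradient with the gradient accumulated over two micro-batches, once with sum reduction divided by the effective batch's token count and once with per-micro-batch mean reduction. float64 keeps rounding noise out of the comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
PAD = -100  # label ignored by F.cross_entropy (its default ignore_index)
vocab = 50

# Hypothetical toy "language model": embedding + linear head, no dropout
model = nn.Sequential(nn.Embedding(vocab, 16), nn.Linear(16, vocab)).double()

# Two sequences as in the post: A has 7 real tokens, B has 10; A is padded to 10
inputs = torch.randint(0, vocab, (2, 10))
labels = torch.randint(0, vocab, (2, 10))
labels[0, 7:] = PAD

def token_loss_sum(inp, lab):
    """Summed cross-entropy over the non-padding tokens of a (micro-)batch."""
    logits = model(inp)  # (batch, seq, vocab)
    return F.cross_entropy(logits.reshape(-1, vocab), lab.reshape(-1),
                           ignore_index=PAD, reduction="sum")

# Token count of the whole effective batch (17 here)
n_tokens = (labels != PAD).sum()

# Full batch: summed loss divided by the token count
model.zero_grad()
(token_loss_sum(inputs, labels) / n_tokens).backward()
full_grad = model[1].weight.grad.clone()

# Accumulation with sum reduction: each micro-batch divided by the effective batch's token count
model.zero_grad()
for i in range(2):
    (token_loss_sum(inputs[i:i + 1], labels[i:i + 1]) / n_tokens).backward()
accum_grad = model[1].weight.grad.clone()

# Accumulation with mean reduction per micro-batch, divided by the number of steps
model.zero_grad()
for i in range(2):
    mb_labels = labels[i:i + 1]
    mb_mean = token_loss_sum(inputs[i:i + 1], mb_labels) / (mb_labels != PAD).sum()
    (mb_mean / 2).backward()
mean_grad = model[1].weight.grad.clone()

print(torch.allclose(full_grad, accum_grad), torch.allclose(full_grad, mean_grad))
```

In a real training loop the token count of the whole effective batch has to be known before the first `backward()`, for example by fetching all micro-batches of a step up front. An alternative with the same result is to backpropagate the unscaled summed losses and multiply the accumulated gradients by 1/n_tokens just before `optimizer.step()`, since the gradient is linear in the loss. With a Hugging Face model this likely means computing the loss from the returned logits yourself instead of using the model's built-in `loss`.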