Loss computation with gradient accumulation

Hello, everyone!
Suppose we have data of shape [b, s, dim]. I recently noticed that CrossEntropyLoss (1) computes the average over all tokens (b * s) in a batch (more precisely, over all positions whose label is not the ignore index, -100), instead of (2) computing the loss for each sentence and then averaging over sentences.
Here is the code that computes the loss in Hugging Face transformers' BertLMHeadModel:

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)
        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            # torch.nn.CrossEntropyLoss with the default reduction="mean": one average over
            # every position in the flattened batch whose label is not ignore_index (-100)
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
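To make behaviour (1) concrete, here is a small sketch with random toy logits (the sizes and the padded positions are assumptions for illustration, not taken from BertLMHeadModel). It checks that the default `reduction="mean"` equals the sum of the per-token losses divided by the number of non-ignored tokens in the whole batch:

```python
import torch
from torch.nn import CrossEntropyLoss

torch.manual_seed(0)
vocab_size = 10                          # toy vocabulary size (assumption)
logits = torch.randn(2, 5, vocab_size)   # [b, s, vocab]
labels = torch.randint(0, vocab_size, (2, 5))
labels[0, 3:] = -100                     # pretend the first sentence is padded after 3 tokens

# Default reduction="mean": averages over all targets != ignore_index (-100) in the batch
mean_loss = CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))

# Same value by hand: sum of per-token losses / number of real (non-ignored) tokens
per_token = CrossEntropyLoss(reduction="none")(logits.view(-1, vocab_size), labels.view(-1))
mask = labels.view(-1) != -100
manual = per_token[mask].sum() / mask.sum()

print(mean_loss.item(), manual.item())
```

The denominator is the token count of the whole batch (8 here), not a per-sentence count.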

I know that (1) and (2) give the same result here, because every sentence in a [b, s, dim] batch contributes the same number of tokens. But when we apply gradient accumulation, I think the situation is different.
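A quick sanity check of the equal-length case, as a sketch with random toy tensors (the sizes are assumptions): when all b sentences have exactly s labelled tokens, the token-level mean (1) and the mean of per-sentence means (2) coincide, since both equal the sum of all token losses divided by b * s.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, vocab_size = 4, 6, 10              # toy sizes (assumption)
logits = torch.randn(b, s, vocab_size)
labels = torch.randint(0, vocab_size, (b, s))

# (1) one mean over all b * s tokens
loss_token_mean = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))

# (2) mean loss per sentence, then mean over the b sentences
per_sentence = torch.stack([F.cross_entropy(logits[i], labels[i]) for i in range(b)])
loss_sentence_mean = per_sentence.mean()

print(loss_token_mean.item(), loss_sentence_mean.item())
```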

Suppose I use batch_size = 4, and the four sentences have lengths of 100, 200, 300, and 400 tokens.
With batch_size = 4, the loss is the average over all 1000 tokens, so every token has weight 1/1000.
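Written out, with $\ell_{i,t}$ the loss of token $t$ in sentence $i$, $n_i$ the length of sentence $i$, and $\bar{\ell}_i$ the mean loss of sentence $i$ (notation introduced here for illustration), the batch_size = 4 loss weights each sentence by its share of the tokens:

$$
\mathcal{L}_{\text{batch}} = \frac{1}{\sum_{i=1}^{4} n_i} \sum_{i=1}^{4} \sum_{t=1}^{n_i} \ell_{i,t}
= \sum_{i=1}^{4} \frac{n_i}{1000}\,\bar{\ell}_i
= 0.1\,\bar{\ell}_1 + 0.2\,\bar{\ell}_2 + 0.3\,\bar{\ell}_3 + 0.4\,\bar{\ell}_4,
\qquad \bar{\ell}_i = \frac{1}{n_i} \sum_{t=1}^{n_i} \ell_{i,t}
$$

For comparison, averaging the four per-sentence means would give every $\bar{\ell}_i$ a weight of 0.25.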

But with batch_size = 1 and gradient accumulation = 4, I think the loss is different. Each sentence's loss is computed separately and the results are then averaged: for the 100-token sentence, we take the mean loss over its 100 tokens, divide it by 4, and add it to the total loss, and likewise for the other three sentences. Each sentence therefore gets weight 1/4 regardless of its length (equivalently, a token in the 100-token sentence gets weight 1/400 while a token in the 400-token sentence gets 1/1600), so I think the loss computed this way differs from the one computed with batch_size = 4.
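The sketch below compares the accumulated gradients directly, using random toy logits as a stand-in for the model output (the vocabulary size and the -100 padding of labels are assumptions for illustration). `full_batch` mimics batch_size = 4, `accumulate_sentence_mean` mimics batch_size = 1 with 4 accumulation steps and the per-step mean loss divided by 4, and `accumulate_token_sum` is a hypothetical variant, not something taken from transformers: each step uses `reduction="sum"` and divides by the total token count of all 4 steps, which assumes that count (1000 here) is known before the backward passes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 50                   # toy vocabulary size (assumption)
lengths = [100, 200, 300, 400]    # sentence lengths from the example
num_steps = len(lengths)          # batch_size = 1, gradient accumulation = 4
total_tokens = sum(lengths)       # 1000

# Random logits stand in for the model output; padding labels are -100 (ignored).
logits = torch.randn(num_steps, max(lengths), vocab_size, requires_grad=True)
labels = torch.randint(0, vocab_size, (num_steps, max(lengths)))
for i, n in enumerate(lengths):
    labels[i, n:] = -100


def full_batch():
    """batch_size = 4: one mean over all 1000 non-padding tokens."""
    logits.grad = None
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
    loss.backward()
    return loss.item(), logits.grad.clone()


def accumulate_sentence_mean():
    """batch_size = 1, 4 steps: per-sentence mean loss divided by 4, grads summed."""
    logits.grad = None
    total = 0.0
    for i, n in enumerate(lengths):
        loss = F.cross_entropy(logits[i, :n], labels[i, :n]) / num_steps
        loss.backward()  # gradients add up in logits.grad across steps
        total += loss.item()
    return total, logits.grad.clone()


def accumulate_token_sum():
    """Hypothetical variant: per-step sum of token losses / total token count."""
    logits.grad = None
    total = 0.0
    for i, n in enumerate(lengths):
        loss = F.cross_entropy(logits[i, :n], labels[i, :n], reduction="sum") / total_tokens
        loss.backward()
        total += loss.item()
    return total, logits.grad.clone()


loss_1, grad_1 = full_batch()
loss_2, grad_2 = accumulate_sentence_mean()
loss_3, grad_3 = accumulate_token_sum()

print(f"batch_size=4:            {loss_1:.6f}")
print(f"accumulation, mean/4:    {loss_2:.6f}")
print(f"accumulation, sum/1000:  {loss_3:.6f}")
print("grads match (mean/4):   ", torch.allclose(grad_1, grad_2))
print("grads match (sum/1000): ", torch.allclose(grad_1, grad_3, atol=1e-7))
```

Because gradients are linear in the loss, per-token gradient weights of 1/(4 n_i) instead of 1/1000 mean the accumulated gradients differ from the batch_size = 4 gradients whenever the sentence lengths differ; normalizing by the total token count makes them agree.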

Did I misunderstand something?
