Yep, only it’s being done for you in the model’s forward pass rather than in the data collator! My understanding is that all of the ModelForTaskX classes define a default loss function in their forward pass, which is only used if you include ‘labels’ in your inputs, and that this is what Trainer uses by default. So for example, if you check out the forward pass of the GPTJForCausalLM class, you’ll notice the exact same ‘shifting’ lines as in the custom loss you noted above:
# from line 846
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
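If you want to check this from the caller’s side, here’s a minimal sketch (using gpt2 purely as a small stand-in checkpoint; any ModelForCausalLM class behaves the same way) that passes labels into the forward pass and then reproduces the returned loss by hand with the same shift-and-flatten logic:

import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is just an arbitrary small checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Trainer uses the model's built-in loss.", return_tensors="pt")
labels = inputs["input_ids"]

# Including labels makes forward() compute and return the default loss
outputs = model(**inputs, labels=labels)

# Reproduce it by hand with the same shift-and-flatten logic as above
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
manual_loss = CrossEntropyLoss()(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

print(torch.allclose(outputs.loss, manual_loss))  # True (up to dtype/precision details)

And when you hand a model like this to Trainer without overriding compute_loss, that outputs.loss is exactly what gets backpropagated.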