Perfect, I forgot to shift the labels.
Just to extend the question, though: I looked at the source code and managed to replicate the loss computed by Hugging Face, but I'm wondering whether it should be computed differently:
# current loss calculation
output_logits = logits[:,:-1,:]
output_mask = mask[:,:-1]
label_tokens = output_tokens["input_ids"][:, 1:].unsqueeze(-1)
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze(-1)  # squeeze only the last dim so a batch of size 1 keeps its batch dim
huggingface_loss = -select_logits.mean()
# proposed loss instead:
seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
seq_loss = -seq_loss.mean()
Happy to create a PR if you agree.
IMO the biggest downside of the existing loss is that when output lengths vary a lot within a batch, the average is dominated by the padding positions that follow the shorter outputs rather than by the real tokens.
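To make the concern concrete, here is a small dependency-free sketch of the arithmetic, not the library's actual implementation. The per-token log-probabilities, the mask, and the helper names `unmasked_loss` and `masked_sequence_loss` are all made up for illustration; the two functions mirror the unmasked mean and the masked per-sequence mean from the snippet above.

```python
# Hypothetical per-token log-probabilities for a batch of two padded sequences.
log_probs = [
    [-1.0, -1.0, -1.0, -1.0],  # 4 real tokens
    [-2.0, -0.5, -0.5, -0.5],  # 1 real token followed by 3 padding positions
]
# 1 marks a real token, 0 marks padding (assumed mask layout).
mask = [
    [1, 1, 1, 1],
    [1, 0, 0, 0],
]


def unmasked_loss(log_probs):
    """Negative mean over every position, padding included."""
    flat = [lp for row in log_probs for lp in row]
    return -sum(flat) / len(flat)


def masked_sequence_loss(log_probs, mask):
    """Negative mean of per-sequence averages taken over real tokens only."""
    per_seq = []
    for row, m in zip(log_probs, mask):
        kept = sum(lp * w for lp, w in zip(row, m))
        per_seq.append(kept / sum(m))
    return -sum(per_seq) / len(per_seq)


print(unmasked_loss(log_probs))               # 0.9375
print(masked_sequence_loss(log_probs, mask))  # 1.5
```

In this toy batch, 3 of the 8 positions in the unmasked mean are padding, so the single real token of the short sequence carries 1/8 of the weight. With the masked per-sequence mean, each sequence carries half of the weight regardless of its length.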