Encoder Decoder Loss

Hi all,
I was reading through the encoder-decoder transformers and saw that a loss is generated, but I'm just wondering how it is computed internally?

Is it something like the following? Suppose I have the pair ("How are you?", "I am doing great"). In this case, does it calculate the cross-entropy loss for the four output tokens and then average them?

I posted a longer version of the above with reproducible code: deep learning - Sequence to Sequence Loss - Stack Overflow

Hi @sachin

The EncoderDecoder model calculates the standard auto-regressive cross-entropy loss using the labels, i.e. the output sequence. It just shifts the labels inside the model before computing the loss.

It's the same loss used in other seq2seq models like BART and T5, and in decoder-only models like GPT-2.
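For the example pair above, here is a minimal sketch of that internal computation (pure PyTorch, with made-up token ids and random numbers standing in for the decoder logits):

import torch
from torch.nn import CrossEntropyLoss

# toy ids: 1 = decoder start / [CLS], 2 = [SEP], 3 = "I", 4 = "am", 5 = "doing", 6 = "great"
labels = torch.tensor([[1, 3, 4, 5, 6, 2]])
vocab_size = 7
logits = torch.randn(1, labels.size(1), vocab_size)    # stand-in for the decoder output scores

# shift: the logits at position t are scored against the token at position t + 1,
# i.e. [CLS] -> "I", "I" -> "am", ..., "great" -> [SEP]
shifted_logits = logits[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

# per-position cross-entropy, averaged over the shifted positions
loss = CrossEntropyLoss()(shifted_logits.view(-1, vocab_size), shifted_labels.view(-1))

So yes, it is essentially the per-token cross-entropy averaged over the (shifted) target positions.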

Hope this helps.

Perfect. I forgot to shift the labels.

Just to extend the question though: I looked at the source code and managed to replicate the Hugging Face loss, but I'm wondering, shouldn't it be:

import torch

# current loss calculation (replicating what Hugging Face computes)
# logits: decoder output scores (assumed to be log-probabilities), shape (batch, seq_len, vocab_size)
# mask: attention mask of the target sequence, shape (batch, seq_len)
# output_tokens: tokenized target sequence (contains "input_ids")
output_logits = logits[:, :-1, :]                                 # drop the last position
output_mask = mask[:, :-1]
label_tokens = output_tokens["input_ids"][:, 1:].unsqueeze(-1)    # shift the labels left by one
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze()
huggingface_loss = -select_logits.mean()

# proposed loss instead: mask out padding, average per sequence, then across the batch
seq_loss = (select_logits * output_mask).sum(dim=-1, keepdim=True) / output_mask.sum(dim=-1, keepdim=True)
seq_loss = -seq_loss.mean()

Happy to create a PR if you agree.

The biggest downside of the existing loss, IMO, is that if there is large variation in the output lengths within a batch, it will put much of its weight on the padding that comes after the final token.
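To make the concern concrete, here is a toy illustration (made-up per-token loss values, not real model outputs) of how much the flat mean and the masked per-sequence mean can diverge when lengths vary:

import torch

# hypothetical per-token losses for two targets: the first has 4 real tokens,
# the second has 1 real token and 3 padding positions (padding is usually
# predicted almost perfectly, so its loss is tiny)
token_loss = torch.tensor([[2.0, 2.0, 2.0, 2.0],
                           [8.0, 0.1, 0.1, 0.1]])
mask = torch.tensor([[1., 1., 1., 1.],
                     [1., 0., 0., 0.]])

# flat mean over every position, padding included
flat = token_loss.mean()                                       # ~2.04

# proposed: mask out padding, average per sequence, then across the batch
per_seq = (token_loss * mask).sum(dim=-1) / mask.sum(dim=-1)   # [2.0, 8.0]
masked = per_seq.mean()                                        # 5.0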

Padding tokens in the labels should be replaced by -100 so the cross-entropy loss ignores the pad tokens when computing the loss.

And the loss is actually computed like this:

# prediction_scores are the decoder logits, shape (batch, seq_len, vocab_size)
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()   # ignore_index defaults to -100, so padded positions are skipped
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
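Putting the two together, a minimal self-contained sketch (made-up token ids, pad id assumed to be 0) of how the -100 replacement and the shift interact:

import torch
from torch.nn import CrossEntropyLoss

pad_token_id = 0                                        # assumed pad id
vocab_size = 10
labels = torch.tensor([[1, 5, 6, 7, 2, 0, 0]])          # padded target ids
labels = labels.masked_fill(labels == pad_token_id, -100)

prediction_scores = torch.randn(1, labels.size(1), vocab_size)  # stand-in for the decoder logits

shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

# ignore_index defaults to -100, so the padded positions drop out of both the
# summed loss and the averaging denominator
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, vocab_size), shifted_labels.view(-1))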

Hello, I am unsure whether I should ask this here or create a separate post, but I was looking at the way the loss is computed, and it seems really confusing to me how the logits are shifted and why it is done in such a way. I have been looking online and haven't managed to find a proper explanation, so could you please help me by explaining why and how the logit shifting is done?