I was reading through the encoder-decoder transformers and saw that a loss is returned, but I’m wondering how it is computed internally.
Is it something like the following? Suppose I have this pair:
("How are you?", "I am doing great"). In this case, is it calculating the cross-entropy loss for each of the four output tokens and then averaging them?
I posted a longer version of the above with reproducible code: deep learning - Sequence to Sequence Loss - Stack Overflow
The EncoderDecoder model calculates the standard auto-regressive cross-entropy loss using the labels, i.e. the output sequence. It just shifts the labels inside the model before computing the loss.
It’s the same loss used in other seq2seq models like BART, T5, and decoder models like GPT2.
Hope this helps.
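To make the averaging in the question concrete, here is a minimal sketch in plain Python. The per-token probabilities are invented for illustration (they are not model output), and a real tokenizer would usually produce more than four target tokens (subwords plus an end-of-sequence token). The point is only that the loss is the mean of the negative log-probabilities the decoder assigns to each correct target token.

```python
import math

# Hypothetical probabilities the decoder assigns to each correct target token
# of "I am doing great" (illustrative values, not real model output).
target_tokens = ["I", "am", "doing", "great"]
p_correct = [0.5, 0.25, 0.8, 0.1]

# Cross-entropy for one token is -log(probability of the correct token).
token_losses = [-math.log(p) for p in p_correct]

# The reported loss is the mean over the target tokens.
loss = sum(token_losses) / len(token_losses)
print(round(loss, 4))
```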
Perfect. I forgot to shift the labels.
Just to extend the question, though: I looked at the source code and managed to replicate the Hugging Face loss, but I’m wondering whether it should instead be:
# current loss calculation
output_logits = logits[:,:-1,:]
output_mask = mask[:,:-1]
label_tokens = output_tokens["input_ids"][:, 1:].unsqueeze(-1)
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze(-1)  # squeeze only the gathered dim so a batch of size 1 keeps its batch axis
huggingface_loss = -select_logits.mean()
# proposed loss instead:
seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
seq_loss = -seq_loss.mean()
Happy to create a PR if you agree.
The biggest downside of the existing loss, IMO, is that if there is large variation in output lengths within a batch, it will focus on the padding after the end-of-sequence token.
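The difference between averaging over tokens and averaging per sequence can be sketched in plain Python. The numbers below are invented, and `token_level_mean` and `sequence_level_mean` are hypothetical helper names; the sketch assumes padded positions are excluded by a mask in both cases.

```python
# Hypothetical per-token negative log-likelihoods for a batch of two targets,
# padded to length 4. mask is 1 for real tokens and 0 for padding.
nll = [
    [2.0, 2.0, 2.0, 2.0],  # long sequence: 4 real tokens
    [0.5, 0.0, 0.0, 0.0],  # short sequence: 1 real token + 3 pads
]
mask = [
    [1, 1, 1, 1],
    [1, 0, 0, 0],
]

def token_level_mean(nll, mask):
    """Average over every non-padded token in the batch."""
    total = sum(l * m for row_l, row_m in zip(nll, mask) for l, m in zip(row_l, row_m))
    count = sum(m for row in mask for m in row)
    return total / count

def sequence_level_mean(nll, mask):
    """Average within each sequence first, then across sequences."""
    per_seq = [
        sum(l * m for l, m in zip(row_l, row_m)) / sum(row_m)
        for row_l, row_m in zip(nll, mask)
    ]
    return sum(per_seq) / len(per_seq)

print(token_level_mean(nll, mask))     # 8.5 / 5 = 1.7
print(sequence_level_mean(nll, mask))  # (2.0 + 0.5) / 2 = 1.25
```

With token-level averaging the long sequence contributes four of the five terms, whereas per-sequence averaging gives both sequences equal weight.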
Padding tokens in the labels should be replaced by -100 so that the cross_entropy loss ignores the pad tokens when computing the loss.
The loss is actually computed like this:
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() # prediction_scores is logits
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
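As a rough check of the -100 convention, the sketch below uses random logits as a stand-in for decoder output. `vocab_size` and `pad_token_id` are made-up values, and the label shift is left out on purpose so that only the masking is shown. It relies on -100 being the default `ignore_index` of PyTorch's cross-entropy loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, pad_token_id = 10, 0  # hypothetical values for illustration

# One target sequence: 3 real tokens followed by 2 pad tokens.
labels = torch.tensor([[4, 7, 2, pad_token_id, pad_token_id]])
logits = torch.randn(1, 5, vocab_size)  # stand-in for decoder output

# Replace padding with -100, the default ignore_index of cross_entropy.
masked_labels = labels.masked_fill(labels == pad_token_id, -100)

# Mean is taken over non-ignored positions only.
loss = F.cross_entropy(logits.view(-1, vocab_size), masked_labels.view(-1))

# Same value computed directly over the 3 real tokens.
manual = F.cross_entropy(logits[0, :3], labels[0, :3])
print(torch.allclose(loss, manual))
```

The two values agree because ignored positions contribute neither to the sum nor to the count used for the mean.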
Hello, I am unsure whether I should ask this here or create a separate post, but I was looking at the way the loss is computed, and I find it confusing how the logits are shifted and why it is done that way. I have been looking online and haven’t managed to find a proper explanation, so could you please explain why and how the logit shifting is done?
Hey, sorry for not replying earlier. The basic reason is that when the tokenizer encodes the target, it produces something like
"<START> My decoded sentence <END>". The decoder transformer only produces predictions for
"My decoded sentence <END>".
So the logits predict the tokens shifted by one (without the <START> token). And the reason we use all the logits except the last one is that the value predicted after <END> is nonsensical, so we simply drop it.
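The alignment can be sketched with a plain list of tokens. The `<START>` and `<END>` strings are placeholders (real tokenizers use their own special tokens), and the slices in the comments refer to the loss code quoted earlier in the thread.

```python
# Hypothetical decoder tokens; real tokenizers use their own special tokens.
tokens = ["<START>", "My", "decoded", "sentence", "<END>"]

# The decoder emits one prediction per input position:
# the output at position i is a prediction for token i + 1.
inputs_seen = tokens[:-1]  # like logits[:, :-1, :]: drop the prediction made after <END>
targets = tokens[1:]       # like labels[:, 1:]: drop <START>, which is never predicted

for i, target in enumerate(targets):
    context = " ".join(tokens[: i + 1])
    print(f"after {context!r} -> predict {target!r}")
```

After the two slices, predictions and targets have the same length and line up position by position.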