Understanding the encoder-decoder loss calculation vs. the CLM loss

Hello!

I'm trying to understand the encoder-decoder loss. In modeling_encoder_decoder.py (line 640), the loss is computed "independent from decoder (as some shift the logits inside them)". Why do we NOT need to shift the logits for next-token prediction here, just like in the CLM loss (e.g. modeling_bert.py line 1255)?

In the loss implemented in modeling_bert.py line 1255:

# we are doing next-token prediction; shift prediction scores and input ids by one 
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

we align the prediction scores with the labels one step ahead via labels = labels[:, 1:] and shifted_prediction_scores = prediction_scores[:, :-1, :], so the score at position i is compared with the token at position i + 1 and the model learns to predict the next token. However, in the loss in modeling_encoder_decoder.py (line 640):

logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

the logits at position i are compared directly with labels[i], so it looks like the model is being asked to reconstruct the labels at the same positions instead of predicting the next token. If my understanding is correct, how would it be able to generate next tokens starting from a decoder_start_token (like in any seq2seq model)?
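
To make the two alignments concrete, here is a small toy comparison I put together (random logits and made-up label ids, purely for illustration):

import torch
from torch.nn import CrossEntropyLoss

# Toy tensors just to show the alignment; shapes and ids are made up.
vocab_size = 10
logits = torch.randn(1, 4, vocab_size)   # decoder logits for 4 positions
labels = torch.tensor([[5, 2, 7, 1]])    # target ids for the same 4 positions
loss_fct = CrossEntropyLoss()

# BERT-style CLM loss: logits at position i are scored against the token at position i + 1.
clm_loss = loss_fct(logits[:, :-1, :].reshape(-1, vocab_size), labels[:, 1:].reshape(-1))

# Encoder-decoder loss as quoted above: no shift, logits at position i are scored against labels[i].
enc_dec_loss = loss_fct(logits.reshape(-1, vocab_size), labels.view(-1))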

The reason I'm asking is that I'm implementing a custom encoder-decoder following modeling_encoder_decoder.py, training everything from scratch, and during generation it always outputs the decoder_start_token. I think this is because the model learns to reconstruct the decoder_input_ids rather than predict the next token. Since the first token it sees is the decoder_start_token, it reconstructs it, appends it to decoder_input_ids, generates again with the updated decoder_input_ids, takes the last item of the output sequence (which is still a decoder_start_token), appends it again, and so on. In the end I just get a sequence of decoder_start_tokens.
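
To make the failure mode concrete, here is a self-contained toy version of that loop. The fake_decoder function is only a stand-in that mimics the behaviour I'm seeing (it always scores the decoder_start_token highest); my real model, vocabulary, and token ids are of course different:

import torch

vocab_size = 10
decoder_start_token_id = 0  # made-up id for illustration

def fake_decoder(decoder_input_ids):
    # Return logits of shape (batch, seq_len, vocab_size) that always favour the start token,
    # reproducing the behaviour I observe with my trained model.
    logits = torch.zeros(decoder_input_ids.size(0), decoder_input_ids.size(1), vocab_size)
    logits[:, :, decoder_start_token_id] = 1.0
    return logits

decoder_input_ids = torch.tensor([[decoder_start_token_id]])
for _ in range(5):
    # Greedily pick the token predicted at the last position and append it.
    next_token = fake_decoder(decoder_input_ids)[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

print(decoder_input_ids)  # tensor([[0, 0, 0, 0, 0, 0]]) -- only decoder_start_tokens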

Could someone please help me understand the procedure or point out where I made a mistake?

Thank you very much :))