What should decoder_input_ids be when pre-training mBART?

A simple question: when pre-training (denoising) with mBART-50, should decoder_input_ids be based on labels (the clean target text) or on input_ids (the corrupted text)?

In some reference code I found, decoder_input_ids appears to be input_ids shifted right: batch["decoder_input_ids"] = self.shift_tokens_right(batch["input_ids"]).

However, many other sources suggest that decoder_input_ids should always be the labels shifted right (i.e. teacher forcing on the clean text), with the corrupted input_ids only reaching the decoder indirectly, via cross-attention over the encoder outputs. (See also "possible mistake in documentation" · Issue #11357 · huggingface/transformers · GitHub, where it is claimed that the documentation's mention of input_ids is a bug and that labels is what is actually used.)
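
To make the labels-based variant concrete, here is a minimal sketch of what I understand it to look like. The checkpoint name is real, but the example sentence, the manual <mask> corruption, and the use of text_target= (transformers >= 4.21 or so) are just my assumptions for illustration:

```python
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from transformers.models.mbart.modeling_mbart import shift_tokens_right

tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50", src_lang="en_XX", tgt_lang="en_XX"
)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")

# Encoder sees the corrupted (noised) text; the clean text is the target.
corrupted = "The <mask> sat on the mat."  # toy corruption, not real span masking
original = "The cat sat on the mat."

input_ids = tokenizer(corrupted, return_tensors="pt").input_ids
labels = tokenizer(text_target=original, return_tensors="pt").input_ids

# Teacher forcing: the decoder input is the *clean* target shifted right,
# so at step t the model predicts labels[t] given labels[:t].
decoder_input_ids = shift_tokens_right(labels, tokenizer.pad_token_id)

outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    labels=labels,
)
print(outputs.loss)
```

If I read modeling_mbart.py correctly, MBartForConditionalGeneration does this itself anyway: when labels is passed and decoder_input_ids is not, it builds decoder_input_ids by calling shift_tokens_right on labels, which would support the labels-based reading.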

So which one is it? I would be immensely grateful for any clarification or guidance!