What is the format of labels for mBART-50?

From the model card, it seems that mBART-50 (in contrast to mBART) expects the input in the format [lang_code] X [eos], with X being the text.
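To make the two layouts concrete, here is a toy sketch (the token ids are illustrative; in the real mBART vocab, as far as I can tell, eos is 2 and en_XX is 250004, while 100/101 just stand in for ordinary text tokens):

```python
# Toy token ids for illustration only
EOS, EN_XX = 2, 250004
X = [100, 101]  # the tokenized text

# original mBART target labels: X [eos] [lang_code]
mbart_labels = X + [EOS, EN_XX]

# mBART-50 labels: [lang_code] X [eos]
mbart50_labels = [EN_XX] + X + [EOS]

print(mbart_labels)    # [100, 101, 2, 250004]
print(mbart50_labels)  # [250004, 100, 101, 2]
```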

However, when doing:

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50")
data_collator = DataCollatorForSeq2Seq(
    ...
    model=model,
    ...
)

then model.prepare_decoder_input_ids_from_labels will be called (in this case MBartForConditionalGeneration.prepare_decoder_input_ids_from_labels), which in turn calls shift_tokens_right, and this is where my confusion lies.

From reading the comment in the source, it seems the function was intended for the older mBART format, where sentences end with the language ID:

Shift input ids one token to the right, and wrap the last non pad token (the <LID> token)

Will it still work with the newer mBART-50, where the last non-pad token is eos instead?
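For concreteness, here is a minimal pure-Python sketch of the wrapping behavior I mean (the real shift_tokens_right in transformers operates on torch tensors and first replaces any -100 in the labels with the pad id; the token ids below are illustrative):

```python
PAD, EOS, LANG = 1, 2, 250004  # pad, eos, and en_XX ids (illustrative)

def shift_tokens_right(input_ids, pad_token_id):
    """Shift each row one token right, wrapping its last non-pad token to the front."""
    shifted = []
    for row in input_ids:
        # index of the last non-pad token (for mBART-50 labels this is eos,
        # for old-style mBART labels it would be the <LID> token)
        last = max(i for i, t in enumerate(row) if t != pad_token_id)
        shifted.append([row[last]] + row[:-1])
    return shifted

# mBART-50-style labels: [lang_code] X [eos], plus padding
labels = [[LANG, 100, 101, EOS, PAD]]
print(shift_tokens_right(labels, PAD))  # [[2, 250004, 100, 101, 2]]
```

So with mBART-50 labels, it is eos rather than a language ID that gets wrapped to position 0 of the decoder input, and I am unsure whether that is the intended behavior.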
