BART generate() output not related to input

I need a model that can fill in multi-token masks, and BART seems like the best choice at the moment.
I’ve trained a BART model from scratch using a custom collator that masks out sequential groups of tokens.
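
For reference, a collator step along these lines might look like the sketch below. It is hypothetical, since the actual collator isn't shown here. It replaces one contiguous span of 1 to `max_span_len` content tokens with a single `<mask>` id, never touches `<s>` or `</s>`, and returns the unmasked sequence as the labels. The helper name and all token ids are made up. BART's own text-infilling noise samples span lengths from a Poisson distribution and can mask several spans per sequence, so this is a simplification.

```python
import random


def mask_one_span(token_ids, mask_id, max_span_len=3, rng=None):
    """Hypothetical span-masking step: replace one contiguous span with a single mask id.

    Assumes token_ids starts with <s> and ends with </s>; neither is ever masked.
    Returns (input_ids, labels), where labels is the original, unmasked sequence.
    """
    rng = rng or random.Random()
    token_ids = list(token_ids)
    content_len = len(token_ids) - 2  # tokens between <s> and </s>
    if content_len < 1:
        return token_ids, list(token_ids)
    span_len = rng.randint(1, min(max_span_len, content_len))
    # Span start is an index into token_ids, kept inside the content region.
    start = rng.randint(1, content_len - span_len + 1)
    input_ids = token_ids[:start] + [mask_id] + token_ids[start + span_len:]
    return input_ids, list(token_ids)


# Hypothetical ids: <s>=0, </s>=2, <mask>=4; 10-13 stand for "This is some input".
inp, lab = mask_one_span([0, 10, 11, 12, 13, 2], mask_id=4, rng=random.Random(0))
print(inp, lab)
```

Collapsing each span to a single `<mask>` means the model also has to decide how many tokens to generate for it, which is what makes this multi-token infilling rather than one-token-per-mask prediction.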

Unfortunately, when I call generate() on my BartForConditionalGeneration model, the output is completely unrelated to the input, with or without mask tokens. The generated text clearly resembles the training data distribution, but it ignores the input entirely, let alone fills in the masks.
However, if I take the logits directly (i.e., model(input_ids).logits), the predictions for the tokens in the input look perfectly fine, so the model does seem to have learned something.
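
One thing worth checking when comparing the two paths: as far as I can tell from the modeling code, when BartForConditionalGeneration is called with only `input_ids`, it builds `decoder_input_ids` by shifting `input_ids` one position to the right, so those logits are teacher-forced on the (masked) input itself. generate() instead starts from `config.decoder_start_token_id` and feeds back its own predictions. The sketch below assumes `torch` and `transformers` are installed and uses a tiny, randomly initialised model with made-up token ids so it runs without a checkpoint; its printed tokens are meaningless, and you would swap in your own trained model and tokenizer to reproduce the comparison.

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny randomly initialised BART so the sketch runs without downloading weights.
config = BartConfig(
    vocab_size=64, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=32,
    pad_token_id=1, bos_token_id=0, eos_token_id=2,
    decoder_start_token_id=2,  # BART's default: decoding starts from </s>
)
torch.manual_seed(0)
model = BartForConditionalGeneration(config).eval()

# Hypothetical ids: <s>=0, </s>=2, <mask>=3, content tokens 10 and 11.
input_ids = torch.tensor([[0, 10, 11, 3, 2]])

with torch.no_grad():
    # No decoder_input_ids given: BART shifts input_ids right and uses that as the
    # decoder prefix, so every position is conditioned on the input's own tokens.
    teacher_forced = model(input_ids).logits.argmax(-1)
    # generate() starts from decoder_start_token_id and feeds back its own tokens.
    generated = model.generate(input_ids, max_length=12, num_beams=1, do_sample=False)

print("teacher-forced:", teacher_forced.tolist())
print("generated:     ", generated.tolist())
```

If the teacher-forced argmax looks good but greedy generation drifts away from the input, the decoder may be leaning on its own prefix rather than on the encoder; passing a known label prefix as `decoder_input_ids` and watching how far the predictions stay on track is one way to probe that.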

The model was trained as follows:
Input: <s>This is <mask></s>
Label: <s>This is some input</s>
I let the library generate the decoder input from the labels; it should look something like this: </s><s>This is some input
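
Concretely, when only labels are passed, BART builds the decoder input by shifting the labels one position to the right, prepending `config.decoder_start_token_id` (`</s>`, id 2, by default) and replacing any `-100` loss padding with `pad_token_id`. Below is a pure-Python mirror of that shift. The function name is hypothetical, the special ids `<s>=0`, `<pad>=1`, `</s>=2` are the stock BART defaults (an assumption for a from-scratch tokenizer), and the content ids are made up.

```python
def shift_right_like_bart(labels, pad_token_id, decoder_start_token_id):
    """Build decoder inputs from labels the way BART does when only labels are given."""
    # Prepend the start token and drop the last label token.
    shifted = [decoder_start_token_id] + list(labels[:-1])
    # -100 marks label positions ignored by the loss; they are not valid ids,
    # so they are replaced with the pad token before reaching the decoder.
    return [pad_token_id if tok == -100 else tok for tok in shifted]


# Hypothetical ids: <s>=0, <pad>=1, </s>=2; 10-13 stand for "This is some input".
labels = [0, 10, 11, 12, 13, 2]  # <s> This is some input </s>
print(shift_right_like_bart(labels, pad_token_id=1, decoder_start_token_id=2))
# [2, 0, 10, 11, 12, 13]  i.e. </s> <s> This is some input
```

The final `</s>` of the labels is dropped, which matches the `</s><s>This is some input` sequence described above.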

Am I missing something obvious here?


Did you make any progress with this? I’m also looking to pretrain BART from scratch for infill generation.