I’m in need of a model that can fill in multi-token masks, and it seems like BART is the best choice at the moment.
I’ve trained a BART model from scratch using a custom collator that masks out sequential groups of tokens.
Unfortunately, when using the generate() method on my BartForConditionalGeneration model, the output is completely unrelated to the input (with or without mask tokens). The generated text clearly comes from the dataset distribution, but it has nothing to do with the input, let alone fills in the masks.
Now, if I obtain the logits directly (i.e., model(input_ids).logits), the predictions for the tokens in the input look perfectly fine, so the model seems to have learned something.
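For concreteness, here is a minimal sketch of the comparison I am describing. It uses a tiny, randomly initialised BartForConditionalGeneration as a stand-in for my trained model, and the token ids (including 50 for the mask) are made up for illustration, so the actual predictions are meaningless. It only shows the two code paths: one forward pass without explicit decoder inputs versus autoregressive decoding with generate().

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

torch.manual_seed(0)

# Tiny stand-in config (assumption: my real model uses different sizes).
config = BartConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = BartForConditionalGeneration(config).eval()

# Hypothetical ids for "<s>This is <mask></s>": 0 = <s>, 2 = </s>, 50 = <mask>.
input_ids = torch.tensor([[0, 10, 11, 50, 2]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Path 1: single forward pass. No decoder_input_ids are passed, so the
    # library builds them itself by shifting input_ids to the right.
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    teacher_forced = logits.argmax(dim=-1)

    # Path 2: autoregressive decoding, starting from the decoder start token.
    generated = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=8,
        num_beams=1,
        do_sample=False,
    )

print("argmax over logits:", teacher_forced.tolist())
print("generate() output: ", generated.tolist())
```

Note that in the first path the decoder sees a shifted copy of the input at every position, whereas in the second it only ever sees its own previous predictions, so the two outputs are not directly comparable.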
The model was trained on input/label pairs like the following (masked input first, then the target):
<s>This is <mask></s>
<s>This is some input</s>
I let the library generate the decoder input, so it should look something like:
</s><s>This is some input.
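To make the shift explicit, here is a plain-Python sketch that mimics what I understand the library's shift_tokens_right helper to do when it derives decoder inputs from labels. The function name shift_right_sketch and the token ids are hypothetical; only the special ids 0 (<s>), 1 (<pad>) and 2 (</s>, also the decoder start token) follow the BART defaults.

```python
def shift_right_sketch(labels, pad_token_id, decoder_start_token_id):
    """Shift labels one position right and prepend the decoder start token.

    Label positions set to -100 (ignored by the loss) become padding.
    """
    shifted = [decoder_start_token_id] + list(labels[:-1])
    return [pad_token_id if tok == -100 else tok for tok in shifted]


BOS, PAD, EOS = 0, 1, 2  # BART default special token ids

# Hypothetical ids for "<s>This is some input</s>"
labels = [BOS, 713, 16, 103, 8135, EOS]

decoder_input_ids = shift_right_sketch(
    labels, pad_token_id=PAD, decoder_start_token_id=EOS
)
print(decoder_input_ids)
```

The final </s> of the labels is dropped by the shift, and the sequence starts with </s> followed by <s>, which matches the decoder input shown above.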
I wonder if I’m missing something obvious here?