Infilling multiple mask spans with BartForConditionalGeneration

I’m using BartForConditionalGeneration to do sentence infilling, and while it works for a single <mask> token, I can’t seem to adapt it to inputs with multiple masks.

The problem seems to be that once it starts generating, it doesn’t “return” to the original sequence, but rather continues generating until it hits a stopping condition. Often it will reproduce the first 2 or 3 tokens immediately following the <mask>, but soon after that it diverges again. In other words, it clearly isn’t copying the unmasked tokens through from the input; it is generating freely and only coincidentally replicating the original context following the mask (which is to be expected, since the generation is conditioned on that context).
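One way to make the divergence concrete is to check which unmasked segments of the input show up, in order, in the generated text. The sketch below is plain string matching with no model involved; the function name `preserved_segments` is hypothetical and the example strings are made up for illustration, not real model output.

```python
from typing import List, Tuple


def preserved_segments(masked_text: str, generated: str,
                       mask: str = "<mask>") -> List[Tuple[str, bool]]:
    """Report, for each unmasked segment of the input, whether it appears
    in the generated text at or after the previously matched segment."""
    results = []
    cursor = 0  # position in `generated` just past the last matched segment
    for segment in masked_text.split(mask):
        segment = segment.strip()
        if not segment:
            continue  # nothing to check between adjacent masks or at the edges
        found_at = generated.find(segment, cursor)
        if found_at == -1:
            results.append((segment, False))
        else:
            results.append((segment, True))
            cursor = found_at + len(segment)
    return results


if __name__ == "__main__":
    # Illustrative strings only: the tail of the input is not reproduced.
    masked = "The <mask> sat on the <mask> and slept."
    output = "The cat sat on the mat and purred loudly."
    for segment, ok in preserved_segments(masked, output):
        print(f"{segment!r}: {'kept' if ok else 'lost'}")
```

A segment reported as lost marks roughly where the output stopped tracking the input.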

I’m guessing this has something to do with how generate() works for this model, but is there a way around it? That is, is there a way to alternate dynamically between actively generating infilling tokens and passing through the unmasked tokens from the original input as context?

The only other solution I can think of is to run multiple iterations of generate(), each with only a single <mask> token, then compile the results when they’re all done. Obviously this would be slower, but I’d expect it to work.
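For concreteness, here is a rough sketch of that one-mask-at-a-time loop. It is not something I have validated against a real checkpoint. The names `infill_sequentially`, `extract_infill`, and `make_bart_fill_fn` are hypothetical, the checkpoint `facebook/bart-large` is an assumption, and showing the model only the text up to the next mask as right context is just one possible choice. Recovering the infill by stripping the known context from the output is a heuristic that can fail when the model rewrites that context.

```python
from typing import Callable

MASK = "<mask>"


def extract_infill(generated: str, left: str, right: str) -> str:
    """Heuristically recover the text generated for the mask by stripping
    the known left and right context from the model output."""
    body = generated.strip()
    left, right = left.strip(), right.strip()
    if left and body.startswith(left):
        body = body[len(left):]
    if right:
        cut = body.find(right)
        if cut != -1:
            body = body[:cut]
    return body.strip()


def infill_sequentially(text: str, fill_fn: Callable[[str, str], str],
                        mask: str = MASK) -> str:
    """Fill masks left to right; each fill_fn call sees exactly one mask.

    fill_fn(left, right) returns the text to put in place of the mask.
    Assumption: `right` is only the unmasked text up to the next mask.
    """
    pieces = text.split(mask)
    result = pieces[0]
    for right in pieces[1:]:
        result = result + fill_fn(result, right) + right
    return result


def make_bart_fill_fn(model_name: str = "facebook/bart-large",
                      max_infill_tokens: int = 20) -> Callable[[str, str], str]:
    """Build a fill_fn backed by BART (checkpoint name is an assumption)."""
    # Imported lazily so the string helpers above work without transformers.
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)

    def fill_fn(left: str, right: str) -> str:
        inputs = tokenizer(left + tokenizer.mask_token + right,
                           return_tensors="pt")
        output_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_beams=4,
            # Leave room for the context plus a bounded infill.
            max_length=inputs["input_ids"].shape[1] + max_infill_tokens,
        )
        generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return extract_infill(generated, left, right)

    return fill_fn
```

Usage would look like `infill_sequentially("The <mask> sat on the <mask>.", make_bart_fill_fn())`. Because each call sees the earlier infills as plain left context, later fills are conditioned on earlier ones, at the cost of one generate() call per mask.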

Any help or thoughts very much appreciated.