So the reason for the two bos tokens in beam search is that the generate function sets the decoder_start_token_id (or the bos_token_id if it isn't defined, which is <s> for BART) as the prefix token, and then adjust_logits_during_generation forces the generation of <s> again when the current length is one.
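For context, the forcing logic looks roughly like this (a paraphrased sketch of what adjust_logits_during_generation does for BART at the first step, not the exact library source):

```python
import torch

def adjust_logits_during_generation(logits, cur_len, bos_token_id):
    # Paraphrased sketch: at the first decoding step (cur_len == 1, i.e. only
    # the decoder_start_token has been emitted so far), every logit except the
    # one for bos_token_id is pushed to -inf, so beam search is forced to
    # produce <s> as the next token -- hence the second bos.
    if cur_len == 1:
        mask = torch.full_like(logits, float("-inf"))
        mask[:, bos_token_id] = 0.0
        logits = logits + mask
    return logits
```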
So I also faced the same issue for my bart model fine-tuned for question generation when using beam search. This is what it produced for one of the inputs.
['<s><s> created Python in 1991?']
So when I just returned the logits as-is from adjust_logits_during_generation instead of forcing another bos_token, it generated this:
['<s>Who created Python?']
Which is what I expected.
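For reference, these outputs came from an ordinary beam-search call along these lines (the checkpoint path and the input sentence are just placeholders for my question-generation setup):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# "path/to/finetuned-bart-qg" is a placeholder for my fine-tuned checkpoint
model = BartForConditionalGeneration.from_pretrained("path/to/finetuned-bart-qg")
tokenizer = BartTokenizer.from_pretrained("path/to/finetuned-bart-qg")

input_ids = tokenizer.encode(
    "Guido van Rossum created Python in 1991.", return_tensors="pt"
)
out = model.generate(input_ids, num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.decode(out[0]))  # '<s><s> created Python in 1991?' with the default behaviour
```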
But as @sshleifer said, when I modified adjust_logits_during_generation as described above, it breaks RUN_SLOW=1 pytest tests/test_modeling_bart.py because the generated summaries don't match the fairseq ones.
So when I looked at the configs of bart-large-cnn and bart-large-xsum, they use </s> (the eos token) as the decoder_start_token_id, and the tests only pass when the first two generated tokens are </s><s>: </s> is set as the first token by generate via decoder_start_token_id, and <s> is set by adjust_logits_during_generation.
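This is easy to check from the configs; the ids in the comment are what I believe the BART vocab uses (0 for <s>, 2 for </s>):

```python
from transformers import BartConfig

for name in ["facebook/bart-large-cnn", "facebook/bart-large-xsum"]:
    cfg = BartConfig.from_pretrained(name)
    # bos_token_id is 0 (<s>) and eos_token_id is 2 (</s>) for BART; the
    # summarization configs set decoder_start_token_id to the eos id.
    print(name, cfg.bos_token_id, cfg.eos_token_id, cfg.decoder_start_token_id)
```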
So I suspected that the weird beam search output for my model was because of the decoder_start_token_id, but when I tried using </s> as the start token like the summarization models, it again gave weird results:
['</s><s> created Python in 1991?']
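That experiment just amounts to overriding the start token at generation time, using the same placeholder model and tokenizer as above:

```python
# decoder_start_token_id can be overridden directly in generate()
# without touching the config
out = model.generate(
    input_ids,
    num_beams=4,
    max_length=32,
    decoder_start_token_id=tokenizer.eos_token_id,  # </s>, like bart-large-cnn
)
print(tokenizer.decode(out[0]))  # '</s><s> created Python in 1991?'
```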
So to me it seems that </s> as the decoder_start_token only works for the pre-trained summarization models and not for other fine-tuned bart-large or bart-base models.
@chrisdoyleIE can you also confirm this with your model?
- first do beam search using eos_token_id as the decoder_start_token_id
- use the bos_token as the start token (as usual) and return the logits as-is from adjust_logits_during_generation instead of forcing <s>. You can do it like this:
```python
def adjust_logits(logits, **kwargs):
    return logits  # pass the logits through unchanged instead of forcing <s>

model.adjust_logits_during_generation = adjust_logits
```
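With that patch in place, the beam-search call from before can be repeated unchanged; for my model it then starts with a single <s>:

```python
out = model.generate(input_ids, num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.decode(out[0]))  # '<s>Who created Python?'
```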
One more question: if </s> is the decoder_start_token, then why don't we set it when fine-tuning the model? Also, in fairseq they seem to use the bos token <s> as the first prefix.
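To make the question concrete, by "setting it when fine-tuning" I mean shifting the decoder inputs so they begin with whatever decoder_start_token we intend to use at generation time. A minimal sketch (illustrative only, not the exact fairseq or transformers preprocessing, and the token ids are made up):

```python
import torch

def make_decoder_input_ids(labels, decoder_start_token_id):
    # Shift the target sequence one position to the right and prepend the
    # decoder start token, which is what the decoder sees during teacher forcing.
    start = torch.full((labels.size(0), 1), decoder_start_token_id, dtype=labels.dtype)
    return torch.cat([start, labels[:, :-1]], dim=1)

# labels for "<s>Who created Python?</s>"; 0 = <s>, 2 = </s>, other ids illustrative
labels = torch.tensor([[0, 12375, 1412, 31886, 116, 2]])
bos_style = make_decoder_input_ids(labels, decoder_start_token_id=0)  # starts with <s>
eos_style = make_decoder_input_ids(labels, decoder_start_token_id=2)  # starts with </s>
```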
Sorry for this long issue, but there was no other way to explain it.