How to force bos_token_id for each example individually in MBart?

Say I have a batch of examples with an input_ids field of size m×n and a bos_token_id field of size n (one BOS token per example). Is there a way to specify the bos_token_id for each example individually during the evaluation step when calling generate?
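For concreteness, here is a minimal sketch of the kind of per-example control being asked about. It sidesteps the single global forced_bos_token_id by building a two-token decoder prefix (decoder start token plus the per-example language token) and letting generate() continue from it. The checkpoint, language codes, and sentences are just placeholders, and this is an illustration rather than a confirmed recipe:

    import torch
    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

    model_name = "facebook/mbart-large-50-many-to-many-mmt"  # placeholder checkpoint
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

    tokenizer.src_lang = "en_XX"
    batch = tokenizer(["Hello world.", "How are you?"], return_tensors="pt", padding=True)

    # One target language (i.e. one BOS/language token) per example.
    tgt_langs = ["fr_XX", "de_DE"]
    bos_ids = torch.tensor([[tokenizer.lang_code_to_id[lang]] for lang in tgt_langs])

    # mBART decoding normally starts with decoder_start_token_id followed by the
    # forced language token, so build that prefix per example and let generate()
    # continue from it instead of passing a single forced_bos_token_id.
    start = torch.full_like(bos_ids, model.config.decoder_start_token_id)
    decoder_input_ids = torch.cat([start, bos_ids], dim=-1)  # shape (batch_size, 2)

    generated = model.generate(**batch, decoder_input_ids=decoder_input_ids, max_length=40)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))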


I’m also curious about this. @mralexis, did you ever work this out? A similar question was also asked here: M2M model finetuning on multiple language pairs, which also went unanswered.


I think I managed to do this, but my way of doing it is really hacky and fragile, so I wouldn’t recommend it. I’ve filed a feature request with the Hugging Face Transformers team to improve this at https://github.com/huggingface/transformers/issues/15500

That feature request links to a Colab notebook with the code showing how I did it.


Hey @nfortescue,

I tried your code and it works when I’m just training, but it seems to run into an error when I enable evaluation during training:

    "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
    AttributeError: 'M2MSeq2SeqTrainer' object has no attribute '_max_length'

The error is raised from this part of the code:

  # XXX: adapt synced_gpus for fairscale as well
  gen_kwargs = {
    "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
    "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
    "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
  }
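One way to sidestep that AttributeError is to read the attributes defensively, since the stock Seq2SeqTrainer only sets _max_length and _num_beams inside evaluate()/predict() when max_length/num_beams are passed in. This is only a sketch of a possible change, not necessarily the exact one used here:

  _max_length = getattr(self, "_max_length", None)
  _num_beams = getattr(self, "_num_beams", None)
  gen_kwargs = {
    # Fall back to the model config when the trainer never set these attributes.
    "max_length": _max_length if _max_length is not None else self.model.config.max_length,
    "num_beams": _num_beams if _num_beams is not None else self.model.config.num_beams,
    "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
  }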

After changing the gen_kwargs, the issue was bypassed, but then another error appeared, TypeError: forward() got an unexpected keyword argument 'forced_bos_token_id', which arose from the following lines:

    with torch.no_grad():
      with self.autocast_smart_context_manager():
        outputs = model(**inputs)

I resolved this by temporarily removing 'forced_bos_token_id' from the inputs before calling the model. However, does that mean the BOS token of the target sequence is now incorrect?
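For reference, a minimal sketch of that kind of workaround, assuming the batch produced by the data collator carries a forced_bos_token_id entry (the key name and where it comes from are assumptions here):

    # Pull the generation-only key out of the batch before the regular forward
    # pass, since the model's forward() does not accept forced_bos_token_id.
    # Keeping the value around lets a custom generation step still use it.
    forced_bos_token_id = inputs.pop("forced_bos_token_id", None)

    with torch.no_grad():
      with self.autocast_smart_context_manager():
        outputs = model(**inputs)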