> it still continues to generate many more tokens than it should
That was exactly my observation too, which led me to think that the model is somehow not learning the EOS token (hence generation doesn't stop where it should).
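One quick way to sanity-check this (just a sketch on my side, assuming a T5-style checkpoint; substitute whatever model you're actually fine-tuning) is to confirm that the tokenized targets end with the EOS token, since the model can only learn to stop if EOS actually appears in the labels:

```python
# Sketch only: check that the tokenizer appends EOS to the target text.
# "t5-small" and the example target string are placeholders for your own setup.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
target_line = "a short reference summary"
ids = tokenizer(target_line).input_ids

print(tokenizer.eos_token, tokenizer.eos_token_id)
print(ids[-1] == tokenizer.eos_token_id)  # should be True if EOS ends up in the labels
```

If the last id isn't EOS, there's nothing in the labels teaching the model to stop, which would match the behavior we're both seeing.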
Re. prefix:
Looks like the prefix is set here: https://github.com/huggingface/transformers/blob/master/examples/seq2seq/utils.py#L243
which then seems to be passed in here (from finetune.py):
```python
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = defaultdict(list)
self.model_type = self.config.model_type
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
self.dataset_kwargs: dict = dict(
    data_dir=self.hparams.data_dir,
    max_source_length=self.hparams.max_source_length,
    prefix=self.model.config.prefix or "",
)
n_observations_per_split = {
    "train": self.hparams.n_train,
    "val": self.hparams.n_val,
    "test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
    "train": self.hparams.max_target_length,
    # ...
```
Where is `self.model.config.prefix` being picked up from? I'm not sure.
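For what it's worth, my understanding (not authoritative) is that `config.prefix` just comes from the model's pretrained config, so you can inspect it directly; something like this should show what the snippet above ends up using:

```python
# Inspection sketch: see what prefix (if any) the pretrained config carries.
# "t5-small" is only an example checkpoint.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("t5-small")
print(config.prefix)                # often None, so finetune.py falls back to ""
print(config.task_specific_params)  # T5 configs keep per-task prefixes like "summarize: " here
```

If that prints `None`, then the `or ""` in the snippet above means no prefix gets prepended to the source lines unless you pass one in explicitly.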