Looking some more, it turns out that `val_max_target_length` is used in `generate` and overrides `model.config.max_length`, as you can see here:

So there is actually a working solution, now that we know which of the 4 args is used to override `max_length`.

I double-checked that this is so with:
```diff
diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 32a96555..7d8f4741 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -216,6 +216,10 @@ class Seq2SeqTrainer(Trainer):
             "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
         }
+        logger.info(f"***** generate args *****")
+        for k, v in sorted(gen_kwargs.items()):
+            logger.info(f"  {k} = {v}")
+
         if self.args.predict_with_generate and not self.args.prediction_loss_only:
             generated_tokens = self.model.generate(
                 inputs["input_ids"],
```
So I'm getting:

```
2020-12-18 16:21:38 | INFO | seq2seq_trainer | ***** generate args *****
2020-12-18 16:21:38 | INFO | seq2seq_trainer |   max_length = 50
2020-12-18 16:21:38 | INFO | seq2seq_trainer |   num_beams = 4
```
So the overriding is indeed happening. But why it is done only with `self.data_args.val_max_target_length`, I don't know.
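To make the mechanism concrete, here is a standalone sketch (not the actual `transformers` code; `SimpleConfig` and `ToyModel` are made-up names) of the pattern at work: an explicit kwarg passed to `generate` wins, and the model config only supplies the fallback default.

```python
# Minimal mimic of the kwarg-over-config precedence seen above.
# These classes are hypothetical stand-ins, not transformers internals.

class SimpleConfig:
    def __init__(self, max_length=20, num_beams=1):
        self.max_length = max_length
        self.num_beams = num_beams

class ToyModel:
    def __init__(self, config):
        self.config = config

    def generate(self, input_ids, **kwargs):
        # Explicit kwargs take precedence; config values are the fallback.
        max_length = kwargs.get("max_length", self.config.max_length)
        num_beams = kwargs.get("num_beams", self.config.num_beams)
        return max_length, num_beams

model = ToyModel(SimpleConfig(max_length=20, num_beams=1))
print(model.generate([1, 2, 3]))                              # (20, 1) - config used
print(model.generate([1, 2, 3], max_length=50, num_beams=4))  # (50, 4) - kwargs win
```

This is why the logged `max_length = 50` above silently takes effect regardless of what `model.config.max_length` says.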
So there are 2 possible things to do here:

- either add explicit `--min_gen_length` and `--max_gen_length` args and pass those into `generate`, or at the very least document that `--val_max_target_length` has a double usage - one for validation dataset truncation, and a secondary use for overriding `generate`'s `max_length`.
- Perhaps that comment about "use task specific params" should be amended to say that further overrides may happen, since the info logger doesn't report that `model.config.max_length` was effectively set to `self.data_args.val_max_target_length`, and thus it is confusing to the user.
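The first item could look roughly like this sketch, assuming dedicated generation flags (`--min_gen_length` / `--max_gen_length` do not exist in the actual script; the defaults here are illustrative):

```python
# Hypothetical sketch of separate generation-length flags, so that
# --val_max_target_length no longer does double duty.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--val_max_target_length", type=int, default=142,
                    help="Truncation length for validation targets only (illustrative default).")
parser.add_argument("--min_gen_length", type=int, default=None,
                    help="Explicit min_length to pass to generate().")
parser.add_argument("--max_gen_length", type=int, default=None,
                    help="Explicit max_length to pass to generate().")
args = parser.parse_args(["--val_max_target_length", "56", "--max_gen_length", "50"])

# Only forward the flags the user set, so model.config values still
# apply as defaults when a flag is left unset.
gen_kwargs = {k: v for k, v in {
    "min_length": args.min_gen_length,
    "max_length": args.max_gen_length,
}.items() if v is not None}
print(gen_kwargs)  # {'max_length': 50}
```

With this split, dataset truncation and generation length are controlled independently, and nothing is overridden behind the user's back.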
I submitted a PR that addresses these 2 items above.

So this flurry of comments cleared up which command-line arg to use to override `max_length`, but I doubt it made any difference to your problem.
If the problem is still unresolved, please help us reproduce it. Ideally use the existing summarization datasets that we use for testing, as explained here:
- cnn_dm https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md#cnndailymail
- xsum https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md#xsum
or, if that doesn't work, please make a small sample that reproduces the problem with your data, along with copy-n-paste instructions to get and deploy it. Thank you!