Generation Config for ByT5

I am training a ByT5 model for a generation task. My initial observation was that the output contains a lot of repetition and the generation quality is poor, so I started looking into the contrastive search approach described here: https://huggingface.co/blog/introducing-csearch. But it seems I cannot override the generation config…
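For reference, the blog turns on contrastive search by passing penalty_alpha and top_k directly to generate(). A minimal standalone sketch of what I am aiming for (plain inference with google/byt5-base, separate from my training setup):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-base")

inputs = tokenizer("example input text", return_tensors="pt")
# contrastive search: degeneration penalty (penalty_alpha) + candidate pool size (top_k)
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))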

The steps I follow are:
I initialize the model using:

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=model_args.cache_dir,
)

Next, my trainer looks like this:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset['train'] if training_args.do_train else None,
    eval_dataset=eval_dataset['train'] if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)

In the above, I try to override the generation configuration using:

from transformers import GenerationConfig

gc = GenerationConfig.from_pretrained("google/byt5-base")
gc.penalty_alpha = 0.6
gc.top_k = 4
training_args.generation_config = gc

But whenever logs are printed I get a warning like:

You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)

and when the configuration files are saved, I get:

{
  "_from_model_config": true,
  "decoder_start_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": 0,
  "transformers_version": "4.29.2"
}

and my generation still has lots of repetition and poor quality. How can I use the generation config properly? @sgugger
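Concretely, is one of these the supported way? This is just a sketch of what I think the docs are pointing at (the generation_config argument of Seq2SeqTrainingArguments and the per-call gen_kwargs are my reading of the API, so please correct me if I am misreading it):

from transformers import GenerationConfig, Seq2SeqTrainingArguments

gen_config = GenerationConfig.from_pretrained("google/byt5-base")
gen_config.penalty_alpha = 0.6
gen_config.top_k = 4

# Option A: attach the config to the model so model.generate()
# (and hence predict_with_generate) picks it up
model.generation_config = gen_config

# Option B: hand the config to the training arguments when they are created,
# instead of mutating training_args afterwards
# training_args = Seq2SeqTrainingArguments(..., predict_with_generate=True,
#                                          generation_config=gen_config)

# Option C: pass the generation kwargs per call; Seq2SeqTrainer.evaluate()/predict()
# forward extra keyword arguments to generate()
metrics = trainer.evaluate(penalty_alpha=0.6, top_k=4, max_new_tokens=128)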
