T5 Finetuning Tips

Hi everyone, I’m giving mt5-small a go locally with 32k samples of varying length, up to 512 tokens each. I’m not super familiar with training this model.
I’ve got 2x 12GB RTX cards.

I previously managed to finetune MarianMT models on this hardware; training was quite stable with 32 gradient accumulation steps, AdamW, and a 5e-7 learning rate for 24 epochs, and the resulting model performs well on the task it was finetuned on.

I’ve tried the same settings with mt5-small, except that I adjusted the learning rate to the advised 3e-4, and I’ve got some concerns that perhaps some of you could help address:

  • Training appears to be going much slower.
  • Why can’t I fit the base size on 2x 12GB with samples of up to 512 tokens?
  • The gradient norm has ridiculously high values.
  • Eval metrics have really low values.

I’m using the same dataset and the same training procedure as for MarianMT (except for the learning rate). Why are the differences so large?
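
I realize mt5-small is a much bigger model than the Marian ones (the ~250k-token vocabulary alone makes the embedding matrices huge), so some slowdown is probably expected, but the memory and gradient-norm issues still puzzle me. For reference, this is the untested sketch of what I’m planning to change in the training arguments next, assuming a reasonably recent transformers release; corrections welcome if these aren’t the right knobs:

    # Untested plan: trade compute for activation memory and shrink optimizer state,
    # which I hope is enough to fit the base size on the 12GB cards.
    training_args = Seq2SeqTrainingArguments(
        output_dir=self.output_dir,
        per_device_train_batch_size=2,   # keep the effective batch size via accumulation
        gradient_accumulation_steps=32,
        gradient_checkpointing=True,     # recompute activations instead of storing them
        optim="adafactor",               # far smaller optimizer state than AdamW
        predict_with_generate=True,
        fp16=False,                      # to rule out fp16 overflow as the grad-norm culprit
    )
    # I believe use_cache should be off while gradient checkpointing is on:
    self.model.config.use_cache = False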

Here are some code snippets from my tuner class that might be relevant:

@valhalla @sshleifer should I instead append the eos token to the end of each label and return the string labels instead of token IDs? Where does the task prefix come from? Is there a list of task prefixes used by Google for mt5? I guess it would be easier if I plugged into an existing prefix. (I’ve added a quick eos check right after this snippet.)


    def _preprocess_function(self, examples):
        """Tokenize and preprocess a batch of examples for model training."""
        # Add task-specific prefix
        prefix = f"<{self.src_key}2{self.mt_key}>"
        inputs = [prefix + text for text in examples[self.src_key]]
        targets = examples[self.mt_key]
        model_inputs = self.tokenizer(inputs, max_length=self.samples_filter_max_tokens, padding='max_length', truncation=True)

        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(targets, max_length=self.samples_filter_max_tokens, padding='max_length', truncation=True)

        # Replace pad token ID with -100 to ignore in loss computation
        labels["input_ids"] = [
            [(label if label != self.tokenizer.pad_token_id else -100) for label in labels_seq]
            for labels_seq in labels["input_ids"]
        ]

        model_inputs['labels'] = labels['input_ids']
        return model_inputs
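
Regarding the eos question above, this is the quick check I was going to run to see whether the tokenizer already appends </s> to the labels on its own (a small sketch, assuming mt5-small’s tokenizer behaves like T5’s):

    # If this prints True, the tokenizer already appends </s> and I don't need
    # to add the eos token to the labels myself.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    ids = tokenizer("a short example sentence").input_ids
    print(ids[-1] == tokenizer.eos_token_id)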

@moscow25 @valhalla @sshleifer @mrm8488 should I just set adafactor=True in my Seq2SeqTrainingArguments?
What about fp16, has anyone had success with it?
And what about multiply_by_parameter_scale=True: is this something I have to additionally configure, or is it a default value? (I’ve put a rough sketch of the explicit optimizer setup after the run() snippet below.)

    def run(self):
        """Execute the main training loop for fine-tuning the model."""
        self.setup()  # Ensure setup is complete: loads model and datasets

        # Preprocess datasets
        self.log.info("Tokenizing training dataset...")
        tokenized_train_dataset = self.train_dataset.map(self._preprocess_function, batched=True)

        self.log.info("Tokenizing validation dataset...")
        tokenized_eval_dataset = self.eval_dataset.map(self._preprocess_function, batched=True)

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            label_pad_token_id=-100, 
        )

        training_args = Seq2SeqTrainingArguments(
            output_dir=self.output_dir,
            logging_steps=self.logging_steps,
            save_steps=self.save_steps,
            eval_steps=self.eval_steps,
            eval_strategy="steps",
            predict_with_generate=True,
            report_to="wandb" if not self.is_test_run else [],
            metric_for_best_model=self.metric_for_best_model,
            greater_is_better=True,
            load_best_model_at_end=True,
            save_total_limit=self.save_total_limit,
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            per_device_train_batch_size=self.per_device_train_batch_size,
            per_device_eval_batch_size=self.per_device_eval_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            auto_find_batch_size=self.auto_find_batch_size,
            num_train_epochs=self.num_train_epochs,
            run_name=self.run_name,
            seed=self.seed,
            fp16=True,  # Enable mixed precision if supported
        )
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            compute_metrics=self._compute_mt_metrics,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )

        trainer.train()
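
As for Adafactor, this is the explicit setup I had in mind as an alternative to just flipping the flag. It’s only a sketch from my reading of the transformers Adafactor docs, so please correct me if the arguments are off (my understanding is that scale_parameter is the HF name for mesh-tf’s multiply_by_parameter_scale):

    # Sketch only: explicit Adafactor plus its schedule, handed to the trainer
    # instead of setting adafactor=True in the training arguments.
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(
        self.model.parameters(),
        scale_parameter=True,  # what mesh-tf calls multiply_by_parameter_scale, I believe
        relative_step=True,    # T5-style relative step sizes; lr must be None in this mode
        warmup_init=True,
        lr=None,
    )
    lr_scheduler = AdafactorSchedule(optimizer)

    trainer = Seq2SeqTrainer(
        model=self.model,
        args=training_args,
        compute_metrics=self._compute_mt_metrics,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_eval_dataset,
        tokenizer=self.tokenizer,
        data_collator=data_collator,
        optimizers=(optimizer, lr_scheduler),  # replaces the default AdamW
    )

I’ve also seen people recommend fixing the learning rate for fine-tuning instead (scale_parameter=False, relative_step=False, warmup_init=False, lr around 1e-3), so I’m not sure which variant applies here.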

Many thanks!