Evaluate Whisper on two different datasets

Hi, I am fine-tuning Whisper on the Spanish common voice training split and evaluating on an interleaved dataset (Spanish and English validation splits).

training_dataset -----> Spanish train split only
validation_dataset ------> Both English and Spanish validation splits

After training, I would like to get english_WER and epanish_WER through the compute_metrics function. This is what I did:

During the interleaving, I kept the column called locale.

#1) I filtered the English and Spanish eval splits
english_validation = vectorized_datasets["validation"].filter(lambda i: i["locale"] == "en")
spanish_validation = vectorized_datasets["validation"].filter(lambda i: i["locale"] == "es")

#2) I use the following callback to compute metrics
    class ComputeMetricsCallback(TrainerCallback):
        def on_evaluate(self, args, state, eval_dataloader, control, **kwargs):
                if isinstance(eval_dataloader.dataset, english_validation):
                        def compute_metrics(pred):
                               pred_ids = pred.predictions
                              pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
                              pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
                              # we do not want to group tokens when computing the metrics
                              label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
                              en_wer = metric.compute(predictions=pred_str, references=label_str)
                              return {"en_wer": en_wer}
                elif isinstance(eval_dataloader.dataset, spanish_validation):
                         def compute_metrics(pred):
                               pred_ids = pred.predictions
                              pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
                              pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
                              # we do not want to group tokens when computing the metrics
                              label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
                              es_wer = metric.compute(predictions=pred_str, references=label_str)
                              return {"es_wer": es_wer}

#3) I pass this to the trainer, but I do not know if I should keep `compute_metrics` and `eval_dataset` in the trainer

trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
        eval_dataset=vectorized_datasets["validation"] if training_args.do_eval else None,
        tokenizer=feature_extractor,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        callbacks=[ComputeMetricsCallback()],
    )

#4) the last step is when doing (do_eval):

    english_results = {}
    spanish_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        # EVAL ENGLISH
        english_metrics = trainer.evaluate(
            eval_dataset = english_validation,
            metric_key_prefix="validation",
            max_length=training_args.generation_max_length,
            num_beams=training_args.generation_num_beams,
        )
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(english_validation["validation"])
        )
        english_metrics["eval_samples"] = min(max_eval_samples, len(english_validation["validation"]))
        print("------> printing english metrics", english_metrics)

        trainer.log_metrics("validation", english_metrics)
        trainer.save_metrics("validation", english_metrics)
        # EVAL SPANISH
        spanish_metrics = trainer.evaluate(
            eval_dataset = spanish_validation,
            metric_key_prefix="validation",
            max_length=training_args.generation_max_length,
            num_beams=training_args.generation_num_beams,
        )
        max_eval_samples = (
            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(spanish_validation["validation"])
        )
        spanish_metrics["eval_samples"] = min(max_eval_samples, len(spanish_validation["validation"]))
        print("------> printing spnaish metrics", spanish_metrics)
        trainer.log_metrics("validation", spanish_metrics)
        trainer.save_metrics("validation", spanish_metrics)

Is this approach right? Or if there is a way to evaluate Whisper on two eval splits and compute WER for each, please let know how to do it. Thanks.