Pass Arguments to custom compute_metrics in Trainer

Hi everyone. I'm passing a custom compute_metrics function to the Trainer. However, I need to provide an additional argument besides the evaluation predictions. How do I do that?

My custom compute_metrics function:

def compute_metrics(p, label_list):
    ...
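For illustration only, here is a hypothetical body for such a function. It assumes a token-classification setup where `p` has `predictions` (per-token logits) and `label_ids` (gold ids, with -100 marking ignored positions), and where `label_list` maps a label id to its name. `EvalPrediction` below is a stand-in namedtuple, not the real transformers class, and `IGNORE_INDEX`/`argmax` are helper names introduced here.

```python
from collections import namedtuple

# Stand-in for transformers.EvalPrediction (assumption: only these two fields are used).
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])

IGNORE_INDEX = -100  # label id conventionally ignored in HF token-classification examples


def argmax(row):
    """Index of the largest score in a sequence of logits."""
    return max(range(len(row)), key=lambda i: row[i])


def compute_metrics(p, label_list):
    """Hypothetical metric: overall token accuracy plus accuracy per label name.

    p.predictions: [batch][seq_len][num_labels] logits
    p.label_ids:   [batch][seq_len] gold label ids (IGNORE_INDEX = skip)
    label_list:    label id -> label name
    """
    correct, total = 0, 0
    per_label = {name: [0, 0] for name in label_list}  # name -> [correct, total]
    for pred_seq, gold_seq in zip(p.predictions, p.label_ids):
        for logits, gold in zip(pred_seq, gold_seq):
            if gold == IGNORE_INDEX:
                continue  # padding / sub-word positions are not scored
            hit = argmax(logits) == gold
            name = label_list[gold]  # this lookup is why label_list is needed
            per_label[name][0] += hit
            per_label[name][1] += 1
            correct += hit
            total += 1
    metrics = {"accuracy": correct / total if total else 0.0}
    for name, (c, t) in per_label.items():
        if t:  # only report labels that actually occur
            metrics[f"accuracy_{name}"] = c / t
    return metrics


example = EvalPrediction(
    predictions=[[[0.1, 0.9], [0.8, 0.2], [0.5, 0.1]]],  # 1 sequence, 3 tokens, 2 labels
    label_ids=[[1, 1, IGNORE_INDEX]],                     # last token is ignored
)
print(compute_metrics(example, ["O", "B-PER"]))  # {'accuracy': 0.5, 'accuracy_B-PER': 0.5}
```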

My trainer:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics)

How do I pass the argument label_list through the Trainer to my compute_metrics function? I couldn't find any solution for this.
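The reason a two-argument function does not work directly: in the standard evaluation path, Trainer calls compute_metrics with a single positional argument, an EvalPrediction. The sketch below mimics that call with a stand-in namedtuple and a hypothetical `call_like_trainer` helper; it is not the actual Trainer code.

```python
from collections import namedtuple

# Stand-in for transformers.EvalPrediction; the real Trainer passes one of these
# as the only positional argument to compute_metrics.
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])


def compute_metrics(p, label_list):
    return {"num_labels": len(label_list)}


def call_like_trainer(metric_fn):
    """Hypothetical helper: invokes metric_fn the way Trainer does, with one argument."""
    return metric_fn(EvalPrediction(predictions=[[0.2, 0.8]], label_ids=[1]))


def try_call(metric_fn):
    """Return the metrics, or the TypeError message if the signature does not fit."""
    try:
        return call_like_trainer(metric_fn)
    except TypeError as err:
        return f"TypeError: {err}"


print(try_call(compute_metrics))  # reports the missing 'label_list' argument
```

So whatever is handed to Trainer must accept just the EvalPrediction; label_list has to be bound in advance.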

Hi, I think this discussion could be useful for you:

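In a nutshell:

def prepare_compute_metrics(label_list):
    def compute_metrics(p):
        # label_list is read from the enclosing scope (a closure);
        # `nonlocal` would only be needed if you reassigned it here.
        ...
    return compute_metrics

compute_metrics = prepare_compute_metrics(label_list)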
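A runnable sketch of the closure approach, using a stand-in namedtuple instead of the real EvalPrediction and a placeholder metric body (both assumptions). It also shows functools.partial, a standard-library alternative that binds label_list as a keyword argument; `compute_metrics_with_labels` is a hypothetical name.

```python
from collections import namedtuple
from functools import partial

# Stand-in for transformers.EvalPrediction (assumption).
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])


def prepare_compute_metrics(label_list):
    # The returned function takes only `p`, but still sees label_list.
    def compute_metrics(p):
        return {"num_labels": len(label_list), "num_examples": len(p.label_ids)}
    return compute_metrics


# Alternative: keep a two-argument function and bind label_list with partial.
def compute_metrics_with_labels(p, label_list):
    return {"num_labels": len(label_list), "num_examples": len(p.label_ids)}


label_list = ["O", "B-PER", "I-PER"]
closure_fn = prepare_compute_metrics(label_list)
partial_fn = partial(compute_metrics_with_labels, label_list=label_list)

eval_pred = EvalPrediction(predictions=[[0.1, 0.9], [0.7, 0.3]], label_ids=[1, 0])
print(closure_fn(eval_pred))  # {'num_labels': 3, 'num_examples': 2}
print(partial_fn(eval_pred))  # {'num_labels': 3, 'num_examples': 2}
```

Either callable can then be passed as compute_metrics=closure_fn (or partial_fn) to Trainer, since both accept the single EvalPrediction argument.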

Thank you! 🙂