How to write custom TrainerCallback functions with custom arguments?

I have a question about how to specify arguments of a custom TrainerCallback. I read in some examples (e.g., doc) that users can pass custom arguments such as model to the EmbeddingPlotCallback.on_evaluate(...) method. Here, model is not a predefined argument of the superclass method TrainerCallback.on_evaluate(...) (doc).

I am wondering how the model is passed to this on_evaluate(...). Should I modify the Trainer class so it calls on_evaluate(...) with additional inputs, or does the Trainer class handle additional arguments automatically? I have not yet found any examples of this. Any advice or pointers to relevant code sections/examples would be very helpful.

To supplement this inquiry with my motivation: I am experimenting with DPOTrainer with reference-model synchronization enabled, and I would like to log information about both the policy model and the reference model. The logging function would therefore need two custom inputs, one for each model. I think I can define two extra arguments on my custom logging function, but I am not sure how I could pass the two models to it.

Any comments will be greatly appreciated!


It seems to be automatic, or rather semi-manual; the Trainer itself does not need to be modified.


by HuggingChat

To specify additional arguments for a custom TrainerCallback function in the Hugging Face Transformers library, you need to design your callback class to accept and store these arguments during initialization. These arguments can then be accessed within the callback’s methods (e.g., on_evaluate). Below is a detailed explanation of how to achieve this:


Step-by-Step Solution

1. Define Your Custom Callback Class

  • Create a custom callback class by subclassing TrainerCallback.
  • In the __init__ method, accept any additional arguments you need (e.g., policy_model and ref_model).
  • Store these arguments as instance variables.
from transformers import TrainerCallback

class MyCustomCallback(TrainerCallback):
    def __init__(self, policy_model, ref_model):
        super().__init__()
        self.policy_model = policy_model
        self.ref_model = ref_model

    def on_evaluate(self, args, state, control, **kwargs):
        # Access your additional arguments here
        print(f"Evaluating with Policy Model: {self.policy_model}")
        print(f"Evaluating with Reference Model: {self.ref_model}")
        return super().on_evaluate(args, state, control, **kwargs)

2. Initialize and Add the Callback to the Trainer

  • Create an instance of your custom callback, passing the required arguments.
  • Add the callback to the Trainer instance using add_callback().
from transformers import Trainer

# Initialize your models and trainer
trainer = Trainer(...)  # Your existing trainer setup

# Create an instance of your custom callback, passing the
# policy and reference models defined earlier in your script
callback = MyCustomCallback(policy_model=policy_model, ref_model=ref_model)

# Add the callback to the trainer
trainer.add_callback(callback)

# Start training
trainer.train()

3. Access Additional Arguments in Callback Methods

  • The additional arguments are now accessible as instance variables in the callback’s methods.
def on_evaluate(self, args, state, control, **kwargs):
    # Use self.policy_model and self.ref_model here
    print(f"Policy Model: {self.policy_model}")
    print(f"Reference Model: {self.ref_model}")
    return super().on_evaluate(args, state, control, **kwargs)

Explanation

  • The Trainer class in Hugging Face Transformers does forward a fixed set of standard objects (model, tokenizer, optimizer, lr_scheduler, and the dataloaders, plus event-specific items such as metrics for on_evaluate) to callback methods via **kwargs; that is how model appears in examples like EmbeddingPlotCallback.on_evaluate(...). Truly custom arguments, however, are not passed automatically. Instead, you must design your callback to accept and store them during initialization.
  • By defining the __init__ method in your custom callback, you can pass any arguments (e.g., policy_model and ref_model) and store them for use within the callback’s methods.
  • The Trainer class will call the callback methods (e.g., on_evaluate) during training, and your custom callback can access the stored arguments to perform the desired functionality.
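
To make the dispatch concrete, here is a small dependency-free sketch of how the Trainer's callback handler injects standard objects into each event call. The class names MiniCallbackHandler and InspectKwargsCallback are illustrative stand-ins, not part of the library; the real CallbackHandler in transformers forwards the model, tokenizer, optimizer, scheduler, and dataloaders in a similar way.

```python
class MiniCallbackHandler:
    """Simplified, dependency-free stand-in for transformers' CallbackHandler."""

    def __init__(self, callbacks, model):
        self.callbacks = callbacks
        self.model = model

    def call_event(self, event, args, state, control, **extra):
        # Standard objects are injected as keyword arguments for every event;
        # this is why `model` can be read from **kwargs inside on_evaluate.
        for cb in self.callbacks:
            getattr(cb, event)(args, state, control, model=self.model, **extra)


class InspectKwargsCallback:
    def __init__(self):
        self.seen = {}

    def on_evaluate(self, args, state, control, **kwargs):
        # Everything the handler injected is visible here
        self.seen = dict(kwargs)


cb = InspectKwargsCallback()
handler = MiniCallbackHandler([cb], model="my-policy-model")
handler.call_event("on_evaluate", args=None, state=None, control=None,
                   metrics={"eval_loss": 0.1})
```

After the call, cb.seen contains both the injected model and the event-specific metrics, mirroring what a real callback receives in **kwargs.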

Why This Works

  • The Trainer class supports custom callbacks by allowing users to add them via add_callback(). The Trainer does not restrict the number or type of arguments passed to the callback during initialization.
  • By storing the additional arguments as instance variables, you ensure they are available throughout the callback’s lifecycle and can be accessed in any method (e.g., on_evaluate, on_train_begin, etc.).
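
Because the stored arguments live on the instance, every lifecycle hook sees the same objects. A minimal illustration of that point (plain Python; the hypothetical ExperimentTagCallback and its experiment_tag argument are invented for this sketch):

```python
class ExperimentTagCallback:
    """Stores a custom argument once; every hook can read it afterwards."""

    def __init__(self, experiment_tag):
        self.experiment_tag = experiment_tag
        self.log = []

    def on_train_begin(self, args, state, control, **kwargs):
        # Same instance variable, first hook
        self.log.append(f"train_begin:{self.experiment_tag}")

    def on_evaluate(self, args, state, control, **kwargs):
        # Same instance variable, later hook
        self.log.append(f"evaluate:{self.experiment_tag}")


cb = ExperimentTagCallback(experiment_tag="dpo-run-1")
cb.on_train_begin(None, None, None)
cb.on_evaluate(None, None, None)
```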

FAQ

  1. Do I need to modify the Trainer class?

    • No. The Trainer class already supports custom callbacks through the add_callback() method. You do not need to modify it to pass additional arguments to your callback.
  2. Can I pass any type of argument to the callback?

    • Yes. You can pass any arguments (e.g., model instances, configuration parameters, or data structures) that your custom callback requires.
  3. Where should I store the additional arguments?

    • Store them as instance variables in the __init__ method of your custom callback. This ensures they are accessible throughout the callback’s lifecycle.

Relevant Examples

  • In the WandbCallback example provided in the documentation [3], the callback accepts the tokenizer and val_dataset as arguments and stores them. Your custom callback can follow the same pattern.
  • For your specific use case, storing policy_model and ref_model in the callback’s __init__ method ensures they are available during evaluation.
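
For the DPO use case specifically, the logging callback could look like the sketch below. To keep it dependency-free, DummyModel objects with a weights list stand in for the real policy and reference nn.Modules, and the statistic computed is a simple weight sum rather than real logging; only the store-at-init pattern is the point.

```python
class RefModelLoggingCallback:
    """Logs the same statistic for both the policy and the reference model."""

    def __init__(self, policy_model, ref_model):
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.history = []

    def on_evaluate(self, args, state, control, **kwargs):
        # Compute the statistic for both stored models at each evaluation
        stats = {
            "policy_weight_sum": sum(self.policy_model.weights),
            "ref_weight_sum": sum(self.ref_model.weights),
        }
        self.history.append(stats)
        return stats


class DummyModel:
    """Placeholder for a real model; exposes only a weights list."""

    def __init__(self, weights):
        self.weights = weights


cb = RefModelLoggingCallback(DummyModel([1.0, 2.0]), DummyModel([1.0, 1.5]))
cb.on_evaluate(None, None, None)
```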

Conclusion

To pass additional arguments to a custom TrainerCallback function, you do not need to modify the Trainer class. Instead, design your custom callback to accept and store these arguments during initialization. The Trainer will call the callback methods during training, and your custom callback can access the stored arguments as needed.

For more details, you can refer to the Hugging Face documentation on callbacks [1][2][3].