How to write custom TrainerCallback functions with custom arguments?

I have a question about how to specify the arguments of a custom TrainerCallback method. I read in some examples (e.g., doc) that users can declare custom arguments such as model in the EmbeddingPlotCallback.on_evaluate(...) method, even though model is not a predefined argument of the superclass method TrainerCallback.on_evaluate(...) (doc).

I am wondering how the model gets passed to this on_evaluate(...). Should I modify the Trainer class so that it calls on_evaluate(...) with additional inputs, or does the Trainer class handle additional arguments automatically? I have not found any examples of this yet. Any advice or pointers to relevant code sections/examples would be very helpful.

To explain my motivation: I am experimenting with DPOTrainer with reference-model synchronization enabled, and I would like to log information about both the policy model and the reference model. My logging function would therefore need two custom inputs, one for each model. I can add two more arguments to my custom logging function, but I am not sure how to pass the two models to it.

Any comments will be greatly appreciated!


It seems to be semi-manual rather than fully automatic, but the Trainer itself does not appear to need any modification.


by Hugging Chat:

To specify additional arguments for a custom TrainerCallback function in the Hugging Face Transformers library, you need to design your callback class to accept and store these arguments during initialization. These arguments can then be accessed within the callback’s methods (e.g., on_evaluate). Below is a detailed explanation of how to achieve this:


Step-by-Step Solution

1. Define Your Custom Callback Class

  • Create a custom callback class by subclassing TrainerCallback.
  • In the __init__ method, accept any additional arguments you need (e.g., policy_model and ref_model).
  • Store these arguments as instance variables.
from transformers import TrainerCallback

class MyCustomCallback(TrainerCallback):
    def __init__(self, policy_model, ref_model):
        super().__init__()
        # Store the extra objects as instance attributes so that every
        # callback method can access them later.
        self.policy_model = policy_model
        self.ref_model = ref_model

    def on_evaluate(self, args, state, control, **kwargs):
        # Access your additional arguments here. The base implementation
        # of on_evaluate is a no-op, so the super() call is optional.
        print(f"Evaluating with Policy Model: {self.policy_model}")
        print(f"Evaluating with Reference Model: {self.ref_model}")
        return super().on_evaluate(args, state, control, **kwargs)

2. Initialize and Add the Callback to the Trainer

  • Create an instance of your custom callback, passing the required arguments.
  • Add the callback to the Trainer instance using add_callback().
from transformers import Trainer

# Initialize your models and trainer
trainer = Trainer(...)  # Your existing trainer setup

# Create an instance of your custom callback
callback = MyCustomCallback(policy_model=policy_model, ref_model=ref_model)

# Add the callback to the trainer
trainer.add_callback(callback)

# Start training
trainer.train()
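
Alternatively, the callback can be registered at construction time through the Trainer's callbacks argument; a minimal sketch, assuming model, training_args, and the two model objects come from your existing setup:

# Equivalent registration via the `callbacks` constructor argument
# instead of calling add_callback() afterwards.
trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[MyCustomCallback(policy_model=policy_model, ref_model=ref_model)],
)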

3. Access Additional Arguments in Callback Methods

  • The additional arguments are now accessible as instance variables in the callback’s methods.
def on_evaluate(self, args, state, control, **kwargs):
    # Use self.policy_model and self.ref_model here
    print(f"Policy Model: {self.policy_model}")
    print(f"Reference Model: {self.ref_model}")
    return super().on_evaluate(args, state, control, **kwargs)

Explanation

  • The Trainer forwards a standard set of objects (e.g., the model, tokenizer, and optimizer) to every callback event through **kwargs, which is why model can appear as a named parameter of on_evaluate (see the sketch after this list). It does not forward arbitrary custom arguments, however, so anything beyond that standard set must be accepted and stored by the callback during initialization.
  • By defining the __init__ method in your custom callback, you can pass any arguments (e.g., policy_model and ref_model) and store them for use within the callback’s methods.
  • The Trainer class will call the callback methods (e.g., on_evaluate) during training, and your custom callback can access the stored arguments to perform the desired functionality.
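
For example, a minimal sketch of picking up the forwarded model directly from the event's keyword arguments (the class name is illustrative):

from transformers import TrainerCallback

class KwargModelCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, model=None, **kwargs):
        # `model` is among the standard objects the Trainer forwards to
        # every callback event via **kwargs.
        if model is not None:
            print(f"Model seen at evaluation: {type(model).__name__}")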

Why This Works

  • The Trainer class supports custom callbacks by allowing users to add them via add_callback(). The Trainer does not restrict the number or type of arguments passed to the callback during initialization.
  • By storing the additional arguments as instance variables, you ensure they are available throughout the callback’s lifecycle and can be accessed in any method (e.g., on_evaluate, on_train_begin, etc.).

FAQ

  1. Do I need to modify the Trainer class?

    • No. The Trainer class already supports custom callbacks through the add_callback() method. You do not need to modify it to pass additional arguments to your callback.
  2. Can I pass any type of argument to the callback?

    • Yes. You can pass any arguments (e.g., model instances, configuration parameters, or data structures) that your custom callback requires; see the sketch after this list.
  3. Where should I store the additional arguments?

    • Store them as instance variables in the __init__ method of your custom callback. This ensures they are accessible throughout the callback’s lifecycle.
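
For instance, a minimal sketch of a callback that stores plain configuration values (the names log_every_n and extra_metadata are illustrative):

from transformers import TrainerCallback

class ConfigurableCallback(TrainerCallback):
    def __init__(self, log_every_n=1, extra_metadata=None):
        super().__init__()
        # Any Python objects can be stored: models, tokenizers, datasets,
        # or plain configuration values like these.
        self.log_every_n = log_every_n
        self.extra_metadata = extra_metadata or {}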

Relevant Examples

  • In the WandbCallback example provided in the documentation [3], the callback accepts the tokenizer and val_dataset as arguments and stores them; your custom callback can follow the same pattern (sketched after this list).
  • For your specific use case, storing policy_model and ref_model in the callback’s __init__ method ensures they are available during evaluation.
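
A minimal sketch of that documentation pattern (the class name is illustrative, and the actual logging logic is omitted):

from transformers.integrations import WandbCallback

class SamplePredictionCallback(WandbCallback):
    def __init__(self, tokenizer, val_dataset):
        super().__init__()
        # Same pattern: extra objects are accepted in __init__ and stored
        # for later use in the callback methods.
        self.tokenizer = tokenizer
        self.val_dataset = val_dataset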

Conclusion

To pass additional arguments to a custom TrainerCallback function, you do not need to modify the Trainer class. Instead, design your custom callback to accept and store these arguments during initialization. The Trainer will call the callback methods during training, and your custom callback can access the stored arguments as needed.

For more details, you can refer to the Hugging Face documentation on callbacks [1][2][3].


Thanks so much for your reply. The approach you described works in my case. For reference, let me describe my use case in more detail and share my current code below.

I am using a DPOTrainer with sync_ref_model enabled, so there is a policy model and a reference model. I also add QLoRA adapters to the models and optimize only the adapters. Here, I want to log the weights of the adapters during training. The weights of the base models are excluded because they should not change during the process.

Below is my custom TensorBoardCallback class for this purpose:

from transformers.integrations import TensorBoardCallback

class PolicyRefModelLoggingCallback(TensorBoardCallback):
    def __init__(self, model, policy_adapter_name=None, ref_adapter_name=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = model
        self.policy_adapter_name = policy_adapter_name
        self.ref_adapter_name = ref_adapter_name

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only log from the main process.
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            # Log a histogram of every trainable adapter weight, once for
            # the policy adapter and once for the reference adapter.
            for adapter_name in (self.policy_adapter_name, self.ref_adapter_name):
                if adapter_name is None:
                    continue
                weight_logs = get_trainable_model_weights(
                    self.model,
                    adapter_name,
                    key_prefix=f"{adapter_name}/",
                )
                for k, v in weight_logs.items():
                    self.tb_writer.add_histogram(k, v, state.global_step)

            self.tb_writer.flush()


def get_trainable_model_weights(model, adapter_name, key_prefix=""):
    # Collect the LoRA A/B matrices that belong to the given adapter.
    weights = {}
    for name, param in model.state_dict().items():
        if (adapter_name in name) and ("lora_A" in name or "lora_B" in name):
            weights[key_prefix + name] = param.detach().cpu()
    return weights

I select the layers of a specific adapter based on its name, which can be set via, for example, PeftModel.from_pretrained(..., adapter_name="..."), as suggested in the DPOTrainer doc section.
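
A minimal usage sketch looks like this (the adapter names "default" and "reference" are illustrative; substitute whatever names are registered on your PEFT model):

# Register the logging callback with an existing DPOTrainer.
trainer = DPOTrainer(...)  # existing DPOTrainer setup

trainer.add_callback(
    PolicyRefModelLoggingCallback(
        model=trainer.model,
        policy_adapter_name="default",
        ref_adapter_name="reference",
    )
)
trainer.train()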

This is my first time writing a TensorBoardCallback subclass, so it may not be well structured or optimized. Any comments on how to improve it are very welcome.


Great!
As far as I can tell from reading the code, there don't seem to be any particular problems, with one caveat: get_trainable_model_weights walks the full state_dict on every call, so calling it several times per log step introduces some overhead. The rest should be within the margin of error.
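
One possible way to reduce that overhead, as a rough sketch (the function name is hypothetical): walk the state_dict once per log step and bucket the weights by adapter, instead of doing one full pass per adapter:

def get_trainable_adapter_weights(model, adapter_names):
    # Hypothetical variant: a single pass over the state_dict that buckets
    # the LoRA A/B weights by adapter name.
    weights = {name: {} for name in adapter_names}
    for key, tensor in model.state_dict().items():
        if "lora_A" not in key and "lora_B" not in key:
            continue
        for adapter_name in adapter_names:
            if adapter_name in key:
                weights[adapter_name][f"{adapter_name}/{key}"] = tensor.detach().cpu()
    return weights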
