I have a question about how to specify arguments for a custom TrainerCallback method. I have seen in some examples (e.g., doc) that users can declare extra arguments such as model in the EmbeddingPlotCallback.on_evaluate(...) function. Here, model is not a predefined argument of the superclass method TrainerCallback.on_evaluate(...) (doc).
I am wondering how the model is passed to this on_evaluate(...). Should I modify the Trainer class to make it call on_evaluate(...) with additional inputs? Or does the Trainer class handle additional arguments automatically? I have not yet found any examples of this. Any advice or pointers to relevant code sections/examples would be very helpful.
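To make the question concrete, here is a minimal sketch of the pattern I mean. It does not import transformers: BaseCallback and fake_dispatch are hypothetical stand-ins that I made up, and the keyword-argument call inside fake_dispatch is only my guess at what the Trainer might do, which is exactly the part I would like confirmed.

```python
# Hypothetical stand-in for transformers.TrainerCallback (not the real class).
class BaseCallback:
    def on_evaluate(self, args, state, control, **kwargs):
        pass


class EmbeddingPlotCallback(BaseCallback):
    # `model` is named explicitly here; any other keyword lands in **kwargs.
    def on_evaluate(self, args, state, control, model=None, **kwargs):
        self.seen_model = model
        self.other_keys = sorted(kwargs)


def fake_dispatch(callback, model, tokenizer):
    # Assumption: the trainer invokes each hook with extra keyword arguments.
    callback.on_evaluate(None, None, None, model=model, tokenizer=tokenizer)
    return callback


cb = fake_dispatch(EmbeddingPlotCallback(), model="policy", tokenizer="tok")
print(cb.seen_model, cb.other_keys)
```

If the real Trainer behaves like fake_dispatch, then naming model in the signature would simply pick it out of the keyword arguments, and no change to the Trainer class would be needed.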
To give some context for my motivation, I am experimenting with DPOTrainer with synchronization of the reference model enabled, and I would like to log information about both the policy model and the reference model. So the logging function would probably need two custom inputs, one for each model. I think I can add two more arguments to my custom logging function, but I am not sure how to pass the two models to it.
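One workaround I am considering is to avoid extra hook arguments altogether and hand the reference model to the callback's constructor instead. The sketch below is plain Python with hypothetical names (DummyModel, TwoModelLoggingCallback); it assumes the hook receives the policy model as a model keyword, and that the reference model object I store is the same one the trainer keeps synchronizing, neither of which I have verified.

```python
from dataclasses import dataclass


@dataclass
class DummyModel:
    # Hypothetical placeholder for a real model object.
    name: str


class TwoModelLoggingCallback:
    def __init__(self, ref_model):
        # The reference model is injected once, at construction time.
        self.ref_model = ref_model
        self.records = []

    def on_log(self, args, state, control, model=None, **kwargs):
        # Assumption: the policy model arrives through the `model` keyword.
        self.records.append(
            {"policy": model.name, "reference": self.ref_model.name}
        )


policy = DummyModel("policy")
reference = DummyModel("reference")
callback = TwoModelLoggingCallback(ref_model=reference)
callback.on_log(None, None, None, model=policy, logs={"loss": 0.5})
print(callback.records)
```

This keeps the hook signature unchanged, but I do not know whether it is the intended way to do this with DPOTrainer.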
Any comments will be greatly appreciated!