How to ensure my custom Trainer is using my custom TrainerState and TrainerControl?

I want to add some custom functionality to my training/evaluation loop and have therefore implemented my own TrainerState and TrainerControl classes that inherit from the respective HF classes. My Trainer likewise inherits from the HF Trainer.

I’m currently initializing my state and control objects inside the Trainer as follows:

class CustomTrainer(Trainer):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Override the defaults that Trainer.__init__ just set.
        self.state = CustomTrainerState(
            custom_arg=self.args.custom_arg
        )
        self.control = CustomTrainerControl()

When I set a pdb breakpoint inside __init__, both objects are properly overridden (the assignments come after the super() call). However, inside the actual training loop the Trainer seems to fall back to the original HF TrainerState, so it doesn't have the custom argument I need.
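For completeness, the same thing is reproducible without pdb (model and training_args below are placeholders for whatever is normally passed in):

trainer = CustomTrainer(model=model, args=training_args)
print(type(trainer.state))   # CustomTrainerState, as assigned in __init__
trainer.train()
print(type(trainer.state))   # plain TrainerState again after training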

What am I doing wrong? What is the correct approach to this?

Thanks in advance!

Edit

After some more investigating, it seems the problem is that my custom TrainerState isn't being passed to the callbacks.

My custom callback looks like the following:

class CustomCallback(TrainerCallback):
    def on_step_end(
        self,
        args,
        state,
        control,
        **kwargs,
    ):
        # Some stuff.
        ...

When I set a pdb breakpoint inside the on_step_end method and check control and state, control is indeed a CustomTrainerControl, but state is a plain TrainerState.

I think I found the problem. The _inner_training_loop method of Trainer reinitializes self.state to a plain TrainerState, discarding whatever __init__ assigned: https://github.com/huggingface/transformers/blob/b7672826cad31e30319487af876e608d8af7d37b/src/transformers/trainer.py. That would also explain why control keeps its custom type: only self.state is recreated there.
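A quick way to confirm the reassignment is there without digging through the file by hand (this matched on the versions I tried; the exact line may differ between releases):

import inspect

from transformers import Trainer

# Search the source of the training loop for the line that rebuilds the state.
src = inspect.getsource(Trainer._inner_training_loop)
print("self.state = TrainerState(" in src)  # True: the loop constructs a fresh TrainerState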

I guess I’ll have to override this method.
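One workaround that avoids copying the whole method: _inner_training_loop looks up TrainerState as a module-level name in transformers.trainer, so rebinding that name makes the loop construct the custom class instead. A minimal sketch, assuming CustomTrainerState can be built with the same arguments the loop passes (i.e. its extra fields have defaults):

import transformers.trainer

# After this, the `self.state = TrainerState(...)` line inside
# _inner_training_loop instantiates CustomTrainerState instead.
transformers.trainer.TrainerState = CustomTrainerState

Note this patch is process-global, so it affects every Trainer in the process, and custom values like custom_arg still need to be copied onto the fresh state, e.g. by having a callback's on_train_begin set state.custom_arg = args.custom_arg.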