I'm a little confused about which callbacks are executed during training.
I want to edit the gradients before the parameter update, and I was about to override the on_step_end()
method in a custom callback. Is this the right way to do it?
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    for n, p in kwargs['model'].named_parameters():
        if p.requires_grad and p.grad is not None:
            safe_set_full_optimizer_state(kwargs['optimizer'], self.current_grads[n])
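For context, here is what I'm ultimately trying to achieve, written in plain PyTorch without the Trainer. The gradient edit has to land after backward() but before optimizer.step(). The zeroing below is just a stand-in for my actual editing logic, and current_grads mirrors the dict my callback keeps:

```python
import torch

# Tiny model and optimizer, just to demonstrate the ordering.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Hypothetical gradient edit: overwrite every gradient with zeros
# (a stand-in for the real editing my callback would do).
current_grads = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
for n, p in model.named_parameters():
    if p.requires_grad and p.grad is not None:
        p.grad.copy_(current_grads[n])

before = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer.step()  # with zeroed grads, plain SGD leaves the parameters unchanged
```

So my real question is just where in the Trainer's callback sequence this edit should go so it runs at the equivalent point, i.e. before the optimizer step rather than after it.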