How to extract gradients during training in PyTorch with the Trainer module?


I am quite familiar overall with the Trainer module and the models, yet it is not clear to me how to customize it to get gradient metrics such as the per-layer norm. What would be the best way?

Thanks in advance for your help!



I am having the same issue. I tried creating a custom callback to log gradients to a JSON file; however, the on_step_end hook is called after model.zero_grad in the training loop, which prevents logging any statistics on the gradients.

Do you have any idea on how to do it differently?

For reference, here is the code for my callback:

import json
import os

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class GradientsCallback(TrainerCallback):
    def __init__(self, norm_type: float = 2.0):
        self.norm_type = float(norm_type)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if control.should_log:
            model = kwargs["model"]
            # Per-parameter gradient norms. This comes back empty, because
            # on_step_end runs after the gradients have been zeroed.
            grads = {
                n: p.grad.norm(self.norm_type).item()
                for n, p in model.named_parameters()
                if p.grad is not None
            }

            gradient_logging_file = os.path.join(args.logging_dir, "gradient_norms.json")

            try:
                with open(gradient_logging_file, "r") as f:
                    data = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                data = {}

            data[state.global_step] = grads
            with open(gradient_logging_file, "w") as f:
                json.dump(data, f)
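One possible workaround, sketched outside the Trainer with a toy model: PyTorch 2.0+ provides `Optimizer.register_step_pre_hook`, which fires just before `optimizer.step()`, while `.grad` is still populated. Since the Trainer drives a regular PyTorch optimizer, the same hook should be attachable to `trainer.optimizer` once it has been created; the model, names, and shapes below are only illustrative.

```python
import torch
from torch import nn

# Toy model/optimizer just to demonstrate the hook mechanism.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

grad_norms = {}  # step index -> {parameter name: gradient norm}


def record_grads(opt, args, kwargs):
    # Runs just before optimizer.step(), so .grad is still populated.
    step = len(grad_norms)
    grad_norms[step] = {
        name: p.grad.norm(2.0).item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }


handle = optimizer.register_step_pre_hook(record_grads)  # PyTorch >= 2.0

x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

print(sorted(grad_norms[0]))  # ['0.bias', '0.weight', '2.bias', '2.weight']
handle.remove()
```

The hook is removable via its handle, so it can be attached only for the steps you want to inspect.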



Unfortunately, I had to re-compute the gradients to log them via the callbacks, which is of course sub-optimal. It doesn’t seem to be possible unless you override some parts of the Trainer code. In my case, I needed the per-sample gradients (Per-sample-gradients — functorch 1.13 documentation), which are not directly available.
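For anyone landing here later: in recent PyTorch the functorch recipe lives under `torch.func`. A minimal per-sample-gradients sketch with a toy linear model (the model, shapes, and loss are only illustrative):

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(4, 1)
params = dict(model.named_parameters())


def loss_fn(params, x, y):
    # functional_call evaluates the model with an explicit parameter dict,
    # which lets grad() differentiate with respect to it.
    pred = functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)
    return torch.nn.functional.mse_loss(pred, y)


# vmap over the batch dimension: one gradient dict per sample.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(
    params, torch.randn(8, 4), torch.randn(8, 1)
)
print(per_sample_grads["weight"].shape)  # torch.Size([8, 1, 4])
```

Each entry keeps the parameter's own shape with a leading batch dimension, so averaging over dim 0 recovers the usual batched gradient.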

I see, thanks for your answer.

I still feel it would be helpful to have at least access to the averaged gradients, for instance for debugging purposes. I think my solution will be to report to Weights & Biases (instead of TensorBoard), which logs gradient histograms.
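For reference, a sketch of the W&B route (assumes `wandb` is installed and you are logged in; the project name and output dir are made up, and `model` stands for your instantiated model):

```python
import wandb
from transformers import TrainingArguments

wandb.init(project="grad-debug")  # hypothetical project name

args = TrainingArguments(
    output_dir="out",
    report_to="wandb",  # route Trainer logs to W&B
)

# wandb.watch registers hooks that log gradient histograms,
# sampled every `log_freq` backward passes.
wandb.watch(model, log="gradients", log_freq=100)
```

After this, train as usual with a Trainer built from these arguments; the gradient histograms appear per-parameter in the W&B run page.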

Hi, I got the same problem and modified the file; in my case it was located at: /usr/local/lib/python3.9/dist-packages/transformers/

Here is my new file; I edited lines 37 and 1884 to plot the gradients every 1000 steps.

Save a copy of the original first, because this modified file always plots the gradients when the train method is called.