Gradient checkpointing without training

I have a LlamaForCausalLM model. I want to do a single run of backprop on a single sample (one forward pass, one backward pass) and record all the gradients that are computed in the process. I do not want to actually update the model weights; I just want to record the gradients. The model is pretty big and I only have a single GPU, so I need gradient checkpointing to fit the backward pass in memory. Is there a way to use a Trainer to accomplish this? Thanks!
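
For concreteness, here's a minimal sketch of what I'm trying to do without a Trainer (the checkpoint name and input text are just placeholders; I'm assuming `gradient_checkpointing_enable()` plus a plain `backward()` call is the right pattern):

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name).cuda()

# Enable gradient checkpointing to trade compute for memory.
model.gradient_checkpointing_enable()
model.train()  # checkpointing only applies in training mode

inputs = tokenizer("a single sample", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])

# One backward pass; no optimizer step, so weights are never updated.
outputs.loss.backward()

# Record the gradients instead of applying them.
grads = {
    name: p.grad.detach().cpu().clone()
    for name, p in model.named_parameters()
    if p.grad is not None
}
```

My question is whether a Trainer can be set up to do the equivalent of this loop (one step, gradients recorded, no weight update), since Trainer handles gradient checkpointing and device placement for me.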