Gradient_checkpointing control

Hi community,

I have a basic intuition of how gradient_checkpointing works: it saves activations only at some layers and recomputes the unsaved ones during backprop, trading extra compute for memory savings.

I wonder how gradient_checkpointing is actually handled under the hood, and whether there is any way to control its behavior, e.g. to reduce the extent of checkpointing.
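To illustrate what I mean by "control", here is a rough sketch of my own (not a built-in API of any library, just plain `torch.utils.checkpoint`) where only every other block is checkpointed, so less recomputation happens during backprop:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self, depth=4, dim=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth)
        )

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                # checkpointed: activations inside this block are dropped
                # and recomputed during the backward pass
                x = checkpoint(block, x, use_reentrant=False)
            else:
                # not checkpointed: activations kept in memory as usual
                x = block(x)
        return x

net = Net()
x = torch.randn(2, 8, requires_grad=True)
net(x).sum().backward()  # gradients flow through both kinds of blocks
```

Is something along these lines what happens internally when I set gradient_checkpointing=True, and is there a supported knob for choosing which layers get checkpointed?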

Thanks!