Functorch with transformers

I want to accelerate per-sample gradient computations with functorch. How do we compile models from transformers to use with functorch?

Thanks


Hi @jpcorb20, were you able to solve this issue? I’m trying to do the same thing right now: getting batched per-sample gradients using functorch (for BERT, ViT, and ResNet).

@pkadambi Hello, unfortunately no. I ended up recomputing the per-sample gradients on the side in pure torch, which is far from efficient, but my research only required the actual numbers in the end…
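
For anyone landing here later: functorch was merged into PyTorch as `torch.func` in PyTorch 2.0, and the usual per-sample gradient pattern there combines `functional_call`, `grad`, and `vmap`. Below is a minimal sketch of that pattern, not a tested recipe for transformers models. The small `nn.Sequential` is a hypothetical stand-in for the real model, and `loss_for_one_sample` is an illustrative name; a BERT or ViT model would presumably also need its inputs passed as keyword arguments and dropout disabled (for example with `model.eval()`), which is an assumption I have not verified here.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical stand-in for a transformers model; any nn.Module follows the same pattern.
model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 3))

# Detached copies of the parameters and buffers, passed explicitly to the functional call.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}


def loss_for_one_sample(params, buffers, x, y):
    # vmap hands us a single sample, so add a batch dimension of 1 for the module.
    logits = functional_call(model, (params, buffers), (x.unsqueeze(0),))
    return nn.functional.cross_entropy(logits, y.unsqueeze(0))


# grad differentiates w.r.t. the first argument (params);
# vmap maps over dim 0 of x and y and shares params and buffers across samples.
per_sample_grads_fn = vmap(grad(loss_for_one_sample), in_dims=(None, None, 0, 0))

x = torch.randn(4, 8)          # batch of 4 samples
y = torch.randint(0, 3, (4,))  # one class label per sample

# Dict with the same keys as params; each tensor has a leading batch dimension.
per_sample_grads = per_sample_grads_fn(params, buffers, x, y)
```

Each entry of `per_sample_grads` holds one gradient per sample stacked along dimension 0, so `per_sample_grads["0.weight"][i]` is the gradient of sample `i`'s loss with respect to the first layer's weight. Whether a given transformers model runs under `vmap` depends on the operations it uses, so it is worth checking the result against a plain loop on a few samples first.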