I want to accelerate per-sample gradient computations with functorch. How can models from the transformers library be made compatible with functorch, i.e., converted into the functional form that vmap/grad expect?
Thanks
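For context, the pattern I have in mind is the standard functorch recipe (make_functional, then vmap over grad). Here is a minimal sketch on a toy linear model; the open question is whether a transformers model can be dropped in its place, since HF forward signatures and internal control flow may not all be vmap-compatible:

```python
import torch
from functorch import make_functional, vmap, grad

# Toy stand-in for a real model; with a transformers model the goal
# would be the same make_functional + vmap(grad(...)) pattern.
model = torch.nn.Linear(4, 2)
fmodel, params = make_functional(model)

def loss_fn(params, x, y):
    # x and y are single samples here; vmap maps over the batch dim
    logits = fmodel(params, x.unsqueeze(0))
    return torch.nn.functional.cross_entropy(logits, y.unsqueeze(0))

x = torch.randn(8, 4)              # batch of 8 samples
y = torch.randint(0, 2, (8,))

# One gradient per sample, computed in a single vectorized call
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)

for g in per_sample_grads:
    print(g.shape)  # each gradient carries a leading batch dim of 8
```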
Hi @jpcorb20, were you able to solve this issue? I’m trying to do the same right now: getting batched per-sample gradients with functorch (for BERT, ViT, and ResNet).
@pkadambi hello, unfortunately no. I ended up recomputing per-sample gradients on the side in pure torch, which is far from efficient, but in the end my research only required the actual gradient values …
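In case it helps anyone landing here, by "on the side in pure torch" I mean a plain per-example backward loop, roughly like this (sketched on a toy model; the model and shapes are placeholders):

```python
import torch

# Placeholder model; any nn.Module with a standard loss works the same way
model = torch.nn.Linear(4, 2)

x = torch.randn(8, 4)              # batch of 8 samples
y = torch.randint(0, 2, (8,))

per_sample_grads = []
for xi, yi in zip(x, y):
    model.zero_grad()
    # One forward/backward per sample: correct but O(batch) passes,
    # which is exactly why this is far slower than vmap(grad(...))
    loss = torch.nn.functional.cross_entropy(
        model(xi.unsqueeze(0)), yi.unsqueeze(0)
    )
    loss.backward()
    per_sample_grads.append(
        [p.grad.detach().clone() for p in model.parameters()]
    )
```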