Hi,
I wrote my own BERT model using dm-haiku (https://github.com/deepmind/dm-haiku) and JAX. My model is at least 15x slower than the Flax models in transformers, and I am trying to see if my jit calls are misplaced.
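For reference, here is a minimal sketch of how I understand jit is usually applied to a haiku model (the toy network, names, and shapes below are placeholders, not my actual model):

```python
import jax
import jax.numpy as jnp
import haiku as hk

# Toy stand-in for the real BERT forward pass.
def forward(x):
    return hk.nets.MLP([256, 256, 2])(x)

model = hk.transform(forward)

rng = jax.random.PRNGKey(0)
dummy = jnp.ones([8, 128])  # (batch, features) placeholder input
params = model.init(rng, dummy)

# jit the pure apply function once; later calls with the same
# input shapes reuse the compiled XLA computation.
apply_fn = jax.jit(model.apply)
logits = apply_fn(params, rng, dummy)
```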
While grepping the files in examples/flax/text-classification or in src/transformers, I don’t find any jit calls. Where is the jitting happening? Is Flax handling that?
Thanks,