Where are the JAX jit annotations in Flax models?

Hi,

I wrote my own BERT model using dm-haiku (GitHub - deepmind/dm-haiku: JAX-based neural network library) and JAX. My model is at least 15x slower than the Flax models you have in transformers, so I am trying to see if my jit calls are misplaced.

While grepping the files in examples/flax/text-classification or in src/transformers, I don’t find any jit calls. Where does the jitting happen? Is Flax handling that?
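For context, in my own code I jit at the outermost level, wrapping the whole transformed apply call rather than individual layers. A minimal sketch of that pattern (the model here is a hypothetical stand-in dense layer, not my actual BERT; names like `apply_fn` are illustrative):

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny "model": a single dense layer standing in for a
# haiku-transformed apply function (purely illustrative).
def apply_fn(params, x):
    return jnp.dot(x, params["w"]) + params["b"]

params = {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)}
x = jnp.ones((3, 4))

# Compile the whole forward pass once; subsequent calls with the same
# shapes reuse the cached XLA executable.
fast_apply = jax.jit(apply_fn)
out = fast_apply(params, x)  # traced and compiled on first call
```

My understanding is that jitting only the outermost call like this should be enough for XLA to fuse everything inside, which is why I am puzzled that I see no explicit `jax.jit` in the Flax examples.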

Thanks,