Hi,
I wrote my own BERT model using dm-haiku (GitHub - deepmind/dm-haiku: JAX-based neural network library) and JAX. My model is at least 15x slower than the Flax models you have in transformers, so I am trying to see if my jit calls are misplaced.
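For context, here is a minimal sketch of how I am currently applying jit, with a toy linear layer standing in for my actual BERT forward pass (the layer size and input shape here are placeholders, not my real model):

```python
import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
    # placeholder for my BERT forward pass; the real model is much larger
    return hk.Linear(128)(x)

# hk.transform turns the function into pure init/apply functions;
# without_apply_rng lets us call apply without passing an rng key
model = hk.without_apply_rng(hk.transform(forward))

rng = jax.random.PRNGKey(0)
x = jnp.ones((8, 64))
params = model.init(rng, x)

# jit the pure apply function once, outside any loop,
# so it is traced and compiled only on the first call
apply_fn = jax.jit(model.apply)
out = apply_fn(params, x)
```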
While grepping the files in examples/flax/text-classification or in src/transformers, I can't find any jit calls. Where is the jitting happening? Is Flax handling that?
Thanks,