How do I write a fast inference function for a Flax model?

I successfully trained a Flax model for sequence classification.
After that, I wrote an inference function that classifies a single arbitrary text.
But it is much slower than the equivalent PyTorch and TensorFlow models.
The inference code is below:

import jax
from transformers import AutoTokenizer, FlaxBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("<model_name_or_path>")
model = FlaxBertForSequenceClassification.from_pretrained("<model_name_or_path>", from_pt=True)

def inference(text: str):
    tokenized = tokenizer(text, return_tensors="jax", truncation=True, max_length=512)
    return model(**tokenized)

(It takes about 2 s in a timeit test, while PyTorch and TensorFlow take about 400 ms.)
(I also tried jitting it, but it still takes about 600 ms; see below.)

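def inference(input_ids):
    # Take token IDs, not a str: jax.jit can only trace array arguments
    return model(input_ids).logits

# TEXT is the single input string
tokenized = tokenizer(TEXT, return_tensors="jax", truncation=True, max_length=512)["input_ids"]
jitted = jax.jit(inference)
logits = jitted(tokenized)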
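One thing worth ruling out before comparing frameworks is the measurement itself. The first call to a `jax.jit` function includes tracing and XLA compilation, and JAX dispatches work asynchronously, so a naive timer can either include compile time or stop before the computation has finished. Below is a minimal sketch of a timing helper under those assumptions; `benchmark`, `warmup`, and `runs` are hypothetical names, not part of JAX or Transformers.

```python
import time

import jax


def benchmark(fn, *args, warmup=1, runs=10):
    """Return the mean wall-clock seconds per call of fn(*args).

    The warm-up calls absorb jax.jit tracing and compilation, and
    jax.block_until_ready waits for JAX's asynchronous dispatch to finish,
    so the timed loop measures only the actual computation.
    """
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(runs):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / runs
```

With the code above, this would be called as `benchmark(jitted, tokenized)`. Note that a new input length produces a new shape, which triggers a fresh compilation, so timing texts of varying length can still include compile time.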

If there is a way to improve inference performance, please share it!
(What I want is the fastest way to run inference, not training.)
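Two common JAX techniques may help here, sketched below under stated assumptions rather than as a tested fix: pass the weights as a jit argument instead of closing over them (closed-over arrays are treated as compile-time constants), and pad inputs to a few fixed lengths so `jax.jit` compiles once per length instead of once per distinct text length. To keep the sketch runnable without downloading a checkpoint, `toy_apply` is a hypothetical stand-in classifier, not BERT; `BUCKETS` and `pad_to_bucket` are likewise illustrative names, and `pad_id=0` assumes BERT's `[PAD]` ID.

```python
import bisect

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical bucket sizes. jax.jit compiles once per distinct input shape,
# so padding every input to one of a few lengths bounds recompilation.
BUCKETS = (32, 64, 128, 256, 512)


def pad_to_bucket(input_ids, pad_id=0, buckets=BUCKETS):
    """Right-pad a list of token IDs to the smallest bucket length that fits.

    Returns (ids, mask) as int32 NumPy arrays of shape (1, bucket).
    pad_id=0 is BERT's [PAD] ID; check tokenizer.pad_token_id for other models.
    """
    n = len(input_ids)
    if n > buckets[-1]:
        raise ValueError(f"length {n} exceeds the largest bucket {buckets[-1]}")
    size = buckets[bisect.bisect_left(buckets, n)]
    ids = np.full((1, size), pad_id, dtype=np.int32)
    mask = np.zeros((1, size), dtype=np.int32)
    ids[0, :n] = input_ids
    mask[0, :n] = 1
    return ids, mask


def toy_apply(params, input_ids, attention_mask):
    """Stand-in classifier (NOT BERT): masked mean of embeddings -> linear head."""
    emb = params["embed"][input_ids]                      # (1, L, D)
    mask = attention_mask[..., None].astype(emb.dtype)    # (1, L, 1)
    pooled = (emb * mask).sum(axis=1) / mask.sum(axis=1)  # (1, D)
    return pooled @ params["head"]                        # (1, num_labels)


# params is a jit argument, not a closed-over constant, so the weights are
# not baked into the compiled program.
fast_inference = jax.jit(toy_apply)
```

For the Flax BERT model in the question, the same pattern would presumably be `fast = jax.jit(lambda params, ids, mask: model(ids, attention_mask=mask, params=params).logits)`, called as `fast(model.params, ids, mask)` with `ids, mask = pad_to_bucket(tokenizer(text, truncation=True, max_length=512)["input_ids"])`. Only the first call for each bucket length should pay the compilation cost; later calls reuse the compiled program.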