JAX and Flax versions used for the new Gemma models

I get the following error when I try to run the Gemma model:

Traceback (most recent call last):
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 266, in __getattr__
    return self.data[item]
KeyError: 'shape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gemma_jax.py", line 17, in <module>
    output = model.generate(inputs, params=params, max_new_tokens=20, do_sample=False)
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/generation/flax_utils.py", line 366, in generate
    batch_size = input_ids.shape[0]
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 268, in __getattr__
    raise AttributeError
AttributeError

Are specific JAX/Flax versions required?

Source:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxGemmaForCausalLM

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"

model, params = FlaxGemmaForCausalLM.from_pretrained(
        model_id,
        dtype=jnp.bfloat16,
        revision="flax",
        _do_init=False,
)

inputs = tokenizer("Valencia and Málaga are", return_tensors="np", padding=True)
output = model.generate(inputs, params=params, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)

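The error does not come from the JAX/Flax versions. tokenizer(...) returns a BatchEncoding, a dict-like object. Passing it positionally makes generate() treat the whole object as input_ids, so its first access, input_ids.shape[0], fails: BatchEncoding looks up "shape" as a key, finds nothing, and raises the AttributeError shown above. Please unpack it instead, so that input_ids and attention_mask are passed as keyword arguments (inputs → **inputs):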

output = model.generate(**inputs, params=params, max_new_tokens=20, do_sample=False)
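The mechanism can be reproduced without downloading the model. The sketch below is self-contained and uses two hypothetical stand-ins, FakeBatchEncoding and fake_generate. They are not the real transformers classes. They only mimic the two lines visible in the traceback: attribute lookup that falls back to dict keys, and generate() reading input_ids.shape[0]. It assumes numpy is installed, and the token IDs are dummy values.

```python
from collections.abc import Mapping
from typing import Any, Dict, Iterator, List, Tuple

import numpy as np


class FakeBatchEncoding(Mapping):
    """Hypothetical stand-in for BatchEncoding: a read-only mapping whose
    attribute lookup falls back to its keys (as in the traceback)."""

    def __init__(self, data: Dict[str, np.ndarray]) -> None:
        self.data = data

    def __getattr__(self, item: str) -> Any:
        # Only called for names not found normally, e.g. "shape".
        try:
            return self.data[item]
        except KeyError:
            # A missing key surfaces as AttributeError, like the real class.
            raise AttributeError(item)

    def __getitem__(self, key: str) -> np.ndarray:
        return self.data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.data)

    def __len__(self) -> int:
        return len(self.data)


def fake_generate(input_ids: Any, params: Any = None, max_new_tokens: int = 20,
                  **model_kwargs: Any) -> Tuple[int, List[str]]:
    """Hypothetical stand-in for the start of a Flax generate() call:
    the first positional argument must be an array with a .shape."""
    batch_size = input_ids.shape[0]
    return batch_size, sorted(model_kwargs)


# Dummy token IDs; a real tokenizer would produce different values.
inputs = FakeBatchEncoding({
    "input_ids": np.array([[2, 651, 603]]),
    "attention_mask": np.array([[1, 1, 1]]),
})

try:
    # Positional: the whole mapping is bound to input_ids.
    fake_generate(inputs, max_new_tokens=20)
except AttributeError as err:
    print("AttributeError:", err)  # AttributeError: shape

# Unpacked: input_ids gets the array, attention_mask goes into model_kwargs.
print(fake_generate(**inputs, max_new_tokens=20))  # (1, ['attention_mask'])
```

With the real tokenizer output, passing inputs["input_ids"] positionally together with attention_mask=inputs["attention_mask"] should be equivalent to **inputs. Unpacking is simply shorter and also forwards any other keys the tokenizer returns.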