I get the following error when I try to run the Gemma model:
Traceback (most recent call last):
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 266, in __getattr__
    return self.data[item]
KeyError: 'shape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gemma_jax.py", line 17, in <module>
    output = model.generate(inputs, params=params, max_new_tokens=20, do_sample=False)
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/generation/flax_utils.py", line 366, in generate
    batch_size = input_ids.shape[0]
  File "/home/ruvaidya/hf/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 268, in __getattr__
    raise AttributeError
AttributeError
Are specific JAX/Flax versions required for this?
Source:
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxGemmaForCausalLM
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"
model, params = FlaxGemmaForCausalLM.from_pretrained(
    model_id,
    dtype=jnp.bfloat16,
    revision="flax",
    _do_init=False,
)
inputs = tokenizer("Valencia and Málaga are", return_tensors="np", padding=True)
output = model.generate(inputs, params=params, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
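Judging from the traceback, this may not be a JAX/Flax version problem. `generate` evaluates `input_ids.shape[0]` (flax_utils.py, line 366), but the script passes the whole tokenizer output, and that object's `__getattr__` (tokenization_utils_base.py, lines 266 to 268) turns the missing `shape` key into a bare `AttributeError`. The sketch below reproduces that pattern without downloading the model. `EncodingLike`, `batch_size_of`, and `generate_like` are hypothetical stand-ins written for illustration, not transformers APIs, and the token ids are placeholders.

```python
import numpy as np
from collections import UserDict


class EncodingLike(UserDict):
    """Hypothetical stand-in for the tokenizer output (not the real BatchEncoding).

    Mirrors the __getattr__ shown in the traceback: attribute access is
    forwarded to the underlying dict, and a missing key becomes AttributeError.
    """

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails (e.g. ".shape").
        data = self.__dict__.get("data", {})
        try:
            return data[item]
        except KeyError:
            raise AttributeError(item) from None


def batch_size_of(input_ids):
    # Same access pattern as flax_utils.generate: input_ids.shape[0]
    return input_ids.shape[0]


def generate_like(input_ids, attention_mask=None, **kwargs):
    # Hypothetical generate-shaped function: input_ids must be an array.
    return batch_size_of(input_ids)


encoding = EncodingLike(
    input_ids=np.array([[10, 20, 30]]),  # placeholder token ids
    attention_mask=np.array([[1, 1, 1]]),
)

# Passing the whole encoding fails the same way as in the traceback.
try:
    batch_size_of(encoding)
    failed = False
except AttributeError:
    failed = True

# Indexing out the array works.
ok_batch = batch_size_of(encoding["input_ids"])

# "**" unpacking binds encoding["input_ids"] to the input_ids parameter
# and forwards attention_mask as a keyword argument.
unpacked_batch = generate_like(**encoding)

print(failed, ok_batch, unpacked_batch)
```

If this is the cause, a likely fix in the real script is to unpack the encoding, e.g. `model.generate(**inputs, params=params, max_new_tokens=20, do_sample=False)`, so that `input_ids` arrives as an array and `attention_mask` as a keyword argument. This is an assumption worth checking against the installed transformers version.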