Hello,
I was trying to generate text with Flax (as an experiment, to see whether it works well on a TPU-VM machine). However, no matter what I tried, it always generates the exact same text for a given prompt. This happened both on a TPU-VM and with local CPU inference.
Here is a short code snippet which demonstrates the problem I encountered:
from transformers import FlaxGPTNeoForCausalLM, AutoTokenizer
model_name = 'EleutherAI/gpt-neo-125M'
model = FlaxGPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_text = "Hello there, my name is"
generated_max_length = 50
# Changing the seed value does not seem to change the outcome
seed = 1001
model.seed = seed
model.config.pad_token_id = model.config.eos_token_id
inputs = tokenizer(prompt_text, return_tensors="jax")
# Changing temperature, top_k, and top_p does not seem to change the outcome either
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_length=generated_max_length,
    do_sample=True,
    temperature=0.8,
    early_stopping=True,
    top_k=50,
    top_p=0.90,
)
output_sequence = outputs['sequences'].squeeze(0)
text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
print(text)
#Always prints:
#Hello there, my name isergus, and I was presented a competition looking for a library keeper with the 31-K--Goods Wallace Leisure library manager on 10-11-08 447-5721. It involves teaching a Royally
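As a quick sanity check that it really is deterministic, two back-to-back sampling calls in the same session return identical sequences:
# Two consecutive sampling calls give identical output:
out1 = model.generate(inputs["input_ids"], max_length=generated_max_length, do_sample=True)
out2 = model.generate(inputs["input_ids"], max_length=generated_max_length, do_sample=True)
print((out1.sequences == out2.sequences).all())  # prints True, despite do_sample=True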
I also tried calling jax.random.PRNGKey(seed), which didn't help, and neither did other approaches such as:
model.top_p = 0.9
model.top_k = 50
jit_generate = jax.jit(model.generate)
# jit_generate(inputs["input_ids"], ....
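For reference, here is a sketch of what I expected might work; note that prng_key is just my guess at a parameter name (I could not find it documented anywhere), so this may be completely wrong:
import jax

# Sketch only: maybe generate wants an explicit PRNG key?
key = jax.random.PRNGKey(seed)
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_length=generated_max_length,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    top_p=0.90,
    prng_key=key,  # hypothetical keyword argument; not sure this is accepted
)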
I assume I’m doing something very wrong, but I was not able to find any example code for generating text with FlaxGPTNeoForCausalLM (I did find examples for training it).
I hope I posted this in the right forum.
Regards,
Doron