FlaxGPTNeoForCausalLM generates the same text regardless of seed, temperature, top_k and top_p values

Hello,

I was trying to generate text using Flax (just as an experiment to see whether it works well on a TPU-VM machine). However, no matter what I tried, it always generates exactly the same text for a given prompt. This happens both on a TPU-VM and with local CPU inference.

Here is a short code snippet which demonstrates the problem I encountered:

```python
from transformers import FlaxGPTNeoForCausalLM, AutoTokenizer

model_name = 'EleutherAI/gpt-neo-125M'
model = FlaxGPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt_text = "Hello there, my name is"
generated_max_length = 50

# Changing the seed value does not seem to change the outcome
seed = 1001
model.seed = seed
model.config.pad_token_id = model.config.eos_token_id

inputs = tokenizer(prompt_text, return_tensors="jax")

# Changing temperature, top_k and top_p does not seem to change the outcome
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_length=generated_max_length,
    do_sample=True,
    temperature=0.8,
    early_stopping=True,
    top_k=50,
    top_p=0.90)

output_sequence = outputs['sequences'].squeeze(0)
text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)

print(text)

# Always prints:
# Hello there, my name isergus, and I was presented a competition looking for a library keeper with the 31-K--Goods Wallace Leisure library manager on 10-11-08 447-5721. It involves teaching a Royally
```
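For reference, a minimal sketch of what `temperature`, `top_k` and `top_p` are meant to do to the next-token logits before a token is drawn. This is a simplified illustration written with plain `jax.numpy`, not the code transformers actually runs; `filter_logits` and `sample_token` are hypothetical helper names. Note that the final draw still depends entirely on the PRNG key that is passed in.

```python
import jax
import jax.numpy as jnp


def filter_logits(logits, temperature=0.8, top_k=50, top_p=0.9):
    """Apply temperature, top-k and top-p filtering to a 1-D vector of logits.

    Simplified illustration only; not the transformers implementation.
    """
    # Temperature: values < 1.0 sharpen the distribution, > 1.0 flatten it.
    logits = logits / temperature

    # Top-k: mask every logit smaller than the k-th largest one.
    k = min(top_k, logits.shape[-1])
    kth_largest = jnp.sort(logits)[-k]
    logits = jnp.where(logits < kth_largest, -jnp.inf, logits)

    # Top-p (nucleus): keep the smallest set of highest-probability tokens
    # whose cumulative probability reaches top_p.
    sorted_idx = jnp.argsort(logits)[::-1]  # token indices, most likely first
    sorted_probs = jax.nn.softmax(logits[sorted_idx])
    cum_probs = jnp.cumsum(sorted_probs)
    # Drop a token if the tokens ranked above it already cover top_p.
    drop_sorted = (cum_probs - sorted_probs) >= top_p
    drop = jnp.zeros_like(drop_sorted).at[sorted_idx].set(drop_sorted)
    return jnp.where(drop, -jnp.inf, logits)


def sample_token(key, logits, **kwargs):
    """Draw one token id; the PRNG key is the only source of randomness."""
    return jax.random.categorical(key, filter_logits(logits, **kwargs))


if __name__ == "__main__":
    toy_logits = jnp.array([2.0, 1.0, 0.5, 0.1, -1.0, 3.0])
    for seed in (0, 1, 1001):
        print(seed, int(sample_token(jax.random.PRNGKey(seed), toy_logits)))
```

With a fixed key, changing the filtering parameters only changes which tokens are eligible; the random draw itself is reproducible.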

I also tried calling jax.random.PRNGKey(seed), which didn’t help, as well as other approaches such as:

```python
import jax

model.top_p = 0.9
model.top_k = 50

jit_generate = jax.jit(model.generate)
# jit_generate(inputs["input_ids"], ...
```
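Why creating `jax.random.PRNGKey(seed)` on its own has no visible effect can be illustrated with plain `jax.random` on a toy four-token distribution (a minimal sketch; `sample_sequence` is a hypothetical helper, not a transformers function). JAX keeps no global random state, so the same key always yields the same samples, and a key only matters if it is passed into the call that does the sampling. If I read the transformers source correctly, Flax `generate` accepts such a key through its `prng_key` argument, e.g. `model.generate(inputs["input_ids"], do_sample=True, prng_key=jax.random.PRNGKey(seed))`, and falls back to a fixed default key when none is given, which would match the identical outputs above.

```python
import jax
import jax.numpy as jnp

# A toy "vocabulary" of 4 tokens with equal probability.
logits = jnp.log(jnp.full((4,), 0.25))


def sample_sequence(key, logits, length):
    """Sample `length` token ids, splitting the key before every draw.

    Hypothetical helper for illustration; JAX has no global RNG state, so the
    randomness comes entirely from `key`.
    """
    tokens = []
    for _ in range(length):
        key, subkey = jax.random.split(key)
        tokens.append(int(jax.random.categorical(subkey, logits)))
    return tokens


if __name__ == "__main__":
    # Creating a key does nothing by itself; it has to be passed in.
    print(sample_sequence(jax.random.PRNGKey(1001), logits, 10))
    print(sample_sequence(jax.random.PRNGKey(1001), logits, 10))  # same key, same tokens
    print(sample_sequence(jax.random.PRNGKey(0), logits, 10))     # different key, independent draw
```

Splitting the key before each draw is the usual JAX pattern: reusing one key for every step would return the same token each time.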

I assume I’m doing something very wrong, but I was not able to find any example code for generating text with FlaxGPTNeoForCausalLM (I did find examples for training it).

I hope I posted this in the right forum.
Regards,
Doron
