I was trying to generate text using Flax (just as an experiment to see if it works well on a TPU-VM machine). However, no matter what I tried, it always generated the exact same text for a given prompt. This happened both on a TPU-VM and with local CPU inference.
Here is a short code snippet which demonstrates the problem I encountered:
```python
from transformers import FlaxGPTNeoForCausalLM, AutoTokenizer

model_name = 'EleutherAI/gpt-neo-125M'
model = FlaxGPTNeoForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt_text = "Hello there, my name is"
generated_max_length = 50

# Changing the seed value does not seem to change the outcome
seed = 1001
model.seed = seed
model.config.pad_token_id = model.config.eos_token_id

inputs = tokenizer(prompt_text, return_tensors="jax")

# Changing temperature, top_k and top_p does not seem to change the outcome
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_length=generated_max_length,
    do_sample=True,
    temperature=0.8,
    early_stopping=True,
    top_k=50,
    top_p=0.90)

output_sequence = outputs['sequences'].squeeze(0)
text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
print(text)

# Always prints:
# Hello there, my name isergus, and I was presented a competition looking for a library keeper with the 31-K--Goods Wallace Leisure library manager on 10-11-08 447-5721. It involves teaching a Royally
```
I also tried creating a key with jax.random.PRNGKey(seed), which didn't help either (a rough sketch of that attempt is below), as well as other approaches such as:
```python
model.top_p = 0.9
model.top_k = 50

jit_generate = jax.jit(model.generate)
# jit_generate(inputs["input_ids"], ....
```
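For completeness, the PRNGKey attempt looked roughly like this (reconstructed from memory, so the exact placement may have differed slightly), and the output still did not change:

```python
import jax

# Create a PRNG key from the seed -- but I could not find anything in
# generate() that actually consumes it, and the output stayed identical.
rng = jax.random.PRNGKey(seed)

# Setting it as a model attribute also had no visible effect.
model.seed = seed
```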
I assume I’m doing something very wrong, but I was not able to find any example code for generating text with FlaxGPTNeoForCausalLM (I did find examples for training it).
I hope I posted this in the right forum.