When I call generate() through a GPT-2 or GPT-Neo pipeline with a temperature of 0, I get an error. Why? This doesn't happen with several other generative models, and a temperature of 0 should arguably be allowed.
Here is a minimal reproducible example:
from transformers import pipeline, set_seed
generator = pipeline('text-generation', model='gpt2')
set_seed(42)
generator("Hello, I'm a language model,", temperature=0, max_length=30, num_return_sequences=1)
Running this produces:
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 210, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1084, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/pipelines/base.py", line 1091, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/pipelines/base.py", line 992, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/pipelines/text_generation.py", line 252, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/generation/utils.py", line 1426, in generate
    logits_warper = self._get_logits_warper(generation_config)
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/generation/utils.py", line 755, in _get_logits_warper
    warpers.append(TemperatureLogitsWarper(generation_config.temperature))
  File "/home/dummy_user/miniconda3/envs/dummy_env/lib/python3.7/site-packages/transformers/generation/logits_process.py", line 172, in __init__
    raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
ValueError: `temperature` has to be a strictly positive float, but is 0
I traced this to TemperatureLogitsWarper in transformers.generation.logits_process, whose constructor rejects any temperature that is not a strictly positive float.
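The restriction makes sense once you look at what the warper does: temperature scaling divides the logits by the temperature before the softmax, so a temperature of 0 is literally a division by zero, and the limit temperature → 0 is just argmax (greedy) decoding. A minimal sketch in plain Python (not the library's actual implementation):

```python
import math

def softmax_with_temperature(logits, temperature):
    # Temperature scaling divides every logit by the temperature before
    # the softmax; temperature == 0 is therefore a division by zero,
    # which is why TemperatureLogitsWarper insists on a strictly
    # positive float.
    scaled = [logit / temperature for logit in logits]  # ZeroDivisionError if temperature == 0
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

print(softmax_with_temperature([2.0, 1.0, 0.5], 1.0))   # ordinary softmax
print(softmax_with_temperature([2.0, 1.0, 0.5], 0.01))  # nearly one-hot: the argmax limit
```

Since the temperature → 0 limit is greedy decoding, the usual substitute for "temperature=0" is to not sample at all: pass do_sample=False and drop the temperature argument, which gives the deterministic output that a zero temperature is typically meant to request.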