I want generation to stop at certain eos_token_ids (e.g. "\n", "."). I keep failing this check in transformers/beam_search.py (at commit c1f85598eb637bb7b21f8c3b15886aaed40ca42e of huggingface/transformers) because there are not enough non-EOS tokens in next_tokens.
- Could someone explain to me what "Make sure {next_tokens[batch_idx]} are corrected." means?
- I tried specifying early_stopping=True in my GenerationConfig, but it doesn't help.
I think in this case, if all beams are saying "stop", we should just stop, or else keep track of the beams that are not stopping. Is there any way to get around this error? In fact, if the code just executed one more round with early_stopping=True, the beam hypothesis in transformers/beam_search.py (same commit) would mark itself as finished. I don't quite understand why an error is thrown here instead.
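For context, here is my understanding of the invariant the check enforces, as a simplified pure-Python sketch (select_continuing_beams is a made-up helper for illustration, not a transformers API): beam search draws the top 2*num_beams candidates so that even after dropping EOS candidates, num_beams continuing beams remain. With a single eos_token_id, each beam contributes EOS at most once, so at most num_beams of the 2*num_beams candidates can be EOS. With three eos ids ([50118, 6, 4]), a single beam can contribute up to three EOS candidates, and the invariant can break:

```python
def select_continuing_beams(candidate_tokens, eos_token_ids, num_beams):
    """From the top 2*num_beams candidates, keep the first num_beams
    non-EOS tokens; fail the way beam_search.py does if too few remain."""
    continuing = [t for t in candidate_tokens if t not in eos_token_ids]
    if len(continuing) < num_beams:
        # roughly where "Make sure ... are corrected." fires
        raise ValueError(f"At most {num_beams} tokens can be eos_token_id")
    return continuing[:num_beams]

# Single EOS id: at most num_beams of the 2*num_beams candidates are EOS,
# so 5 continuing beams can always be filled:
select_continuing_beams(
    [50118, 10, 50118, 11, 12, 13, 14, 15, 16, 17], {50118}, num_beams=5
)  # → [10, 11, 12, 13, 14]

# Three EOS ids: more than num_beams candidates can be EOS at once,
# leaving only 4 non-EOS tokens for 5 beams, so this raises ValueError:
# select_continuing_beams(
#     [50118, 6, 4, 50118, 6, 4, 10, 11, 12, 13], {50118, 6, 4}, num_beams=5
# )
```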
To replicate the issue:
import torch
import transformers

device = 'cuda:0'
input_ids = torch.tensor([[ 2, 33683, 209, 1142, 35, 50118, 1864, 35, 96, 3430, 10, 258, 219, 73, 17143, 324, 16, 10, 116, 50118, 250, 35, 446, 50118, 1864, 35, 3394, 56, 10, 1510, 29, 440, 112, 478, 19, 22426, 370, 404, 2306, 116, 50118, 250, 35]])
model = transformers.OPTForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.float16)
model.to(device)
generation_config = {"bad_words_ids": [[2, 1864, 35]], "early_stopping": True, "eos_token_id": [50118, 6, 4], "max_new_tokens": 256, "pad_token_id": 2}
generation_config = transformers.GenerationConfig(**generation_config)
model.generate(
    input_ids.to(device),
    attention_mask=torch.ones_like(input_ids).to(device),
    num_beams=5,
    num_return_sequences=1,
    do_sample=False,
    generation_config=generation_config,
)
I noticed, however, that if we don't pass attention_mask, no error is raised for this input. The behavior is reversed with the following input (i.e. an error without attention_mask, and no error with it):
input_ids = torch.tensor([[ 2, 33683, 209, 1142, 35, 50118, 1864, 35, 96, 3430, 10, 258, 219, 73, 17143, 324, 16, 10, 116, 50118, 250, 35, 446, 50118, 1864, 35, 3394, 21, 5, 313, 639, 20, 11055, 20614, 2258, 116, 50118, 250, 35]])