Potential bug with beam search + eos_token_id

I want the generation to stop at certain eos_token_ids (e.g. “\n”, “.”). I keep hitting an error in transformers/beam_search.py (at commit c1f85598eb637bb7b21f8c3b15886aaed40ca42e) because there are not enough non-EOS tokens among the next_tokens candidates.

  1. Could someone explain to me what “Make sure {next_tokens[batch_idx]} are corrected.” means?
  2. I tried specifying early_stopping=True in my GenerationConfig, but it doesn’t help.

I think in this case, if all beams are saying “stop”, we should just stop, or at least keep track of the beams that have not stopped. Is there any way to get around this error? In fact, if the code just executed one more round with early_stopping=True, the check in beam_search.py (same commit) would simply mark the hypothesis as finished, so I don’t quite understand why an error is thrown here.
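As far as I can tell from the source, generate() collects the 2 * num_beams highest-scoring candidate tokens per batch item, so that even if some candidates are EOS there are still num_beams non-EOS continuations left; the error fires when that reserve runs out. A minimal sketch of that check (illustrative only, simplified from BeamSearchScorer.process, not the exact source):

# Simplified sketch of the EOS check in BeamSearchScorer.process
# (illustrative only; the token values are taken from the repro below).
num_beams = 5
eos_token_id = [50118, 6, 4]   # the three stop tokens from my config
next_tokens = [50118, 4, 6, 50118, 4, 50118, 6, 4, 50118, 9]  # 2 * num_beams candidates

non_eos = [t for t in next_tokens if t not in eos_token_id]
if len(non_eos) < num_beams:
    # With three eos ids it is easy for more than num_beams of the ten
    # candidates to be EOS, and this is where the ValueError comes from.
    raise ValueError(
        f"At most {num_beams} tokens in {next_tokens} can be equal to "
        f"`eos_token_id`. Make sure {next_tokens} are corrected."
    )

With a single eos_token_id this almost never happens, but with three of them it is much more likely that most of the top candidates are EOS at the same step.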

To replicate the issue:

import torch
import transformers

device = "cuda:0"

# The prompt that triggers the error
input_ids = torch.tensor([[2, 33683, 209, 1142, 35, 50118, 1864, 35, 96, 3430, 10, 258, 219, 73, 17143, 324, 16, 10, 116, 50118, 250, 35, 446, 50118, 1864, 35, 3394, 56, 10, 1510, 29, 440, 112, 478, 19, 22426, 370, 404, 2306, 116, 50118, 250, 35]])

model = transformers.OPTForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.float16)
model.to(device)

# Several eos_token_ids at once, plus early_stopping, as described above
generation_config = transformers.GenerationConfig(
    bad_words_ids=[[2, 1864, 35]],
    early_stopping=True,
    eos_token_id=[50118, 6, 4],
    max_new_tokens=256,
    pad_token_id=2,
)

model.generate(input_ids.to(device), attention_mask=torch.ones_like(input_ids).to(device),
               num_beams=5, num_return_sequences=1,
               do_sample=False, generation_config=generation_config)

I noticed, however, that if we don’t pass attention_mask, no error is raised for this input. The behavior is reversed with the following input (i.e. an error without attention_mask and no error with it):

input_ids = torch.tensor([[ 2, 33683, 209, 1142, 35, 50118, 1864, 35, 96, 3430, 10, 258, 219, 73, 17143, 324, 16, 10, 116, 50118, 250, 35, 446, 50118, 1864, 35, 3394, 21, 5, 313, 639, 20, 11055, 20614, 2258, 116, 50118, 250, 35]])
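My guess as to why attention_mask changes things (an assumption from reading generate(), not something I have verified): when no mask is passed, transformers infers one from pad_token_id, and since pad_token_id=2 here is also the very first token of the prompt, the inferred mask is not all-ones:

# Sketch of the mask generate() builds when none is passed (simplified
# from _prepare_attention_mask_for_generation; my reading of the code):
inferred_mask = input_ids.ne(2).long()   # pad_token_id == 2
# The prompt starts with token id 2, so position 0 is masked out here,
# unlike torch.ones_like(input_ids). The slightly different logits can
# reshuffle the top 2 * num_beams candidates, which is enough to flip
# whether the EOS check above trips for a given input.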

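One way I can think of to sidestep the assertion (my own sketch, not an official fix): give beam search only a single eos_token_id and cut the output at the remaining stop tokens in post-processing:

# Workaround sketch (my own idea, not verified against the library):
# beam search only ever sees one EOS id, the rest are handled afterwards.
single_eos_config = transformers.GenerationConfig(
    bad_words_ids=[[2, 1864, 35]],
    early_stopping=True,
    eos_token_id=50118,          # keep just one stop token for beam search
    max_new_tokens=256,
    pad_token_id=2,
)
output = model.generate(input_ids.to(device),
                        attention_mask=torch.ones_like(input_ids).to(device),
                        num_beams=5, num_return_sequences=1,
                        do_sample=False, generation_config=single_eos_config)

# Truncate the generated continuation at the first remaining stop token.
extra_stops = {6, 4}             # the other eos ids from the original config
new_tokens = output[0, input_ids.shape[1]:].tolist()
for i, tok in enumerate(new_tokens):
    if tok in extra_stops:
        new_tokens = new_tokens[:i + 1]
        break

Note that this changes which beams win (sequences that would have ended at the other stop tokens keep being scored past them), so it only approximates the intended behavior.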
Hey, I have the same issue with the same inputs; the only difference is that my generation mode is BEAM_SAMPLE. Have you had any luck mitigating it?
@zl7