Invalidate beam in do_sample mode with LogitsProcessor by setting it to -inf

When using a custom LogitsProcessor I am able to invalidate a beam by setting its score to

scores[beam_idx, :] = -float("inf")  # Set all scores to negative infinity to stop further generation

This works fine when using do_sample=False. However, whenever I do this with sampling enabled, I get a RuntimeError: probability tensor contains either inf, nan or element < 0.

How do I invalidate all tokens of a beam with do_sample=True ?

If I set this to

scores[beam_idx, :] = -float(1e19) # go down very low but not -inf without overflowing the variable  

I can see the effect, though random tokens are produced.

Edit: Minimal Example

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    LogitsProcessor,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    BeamSearchScorer,
)

import torch

md = 'EleutherAI/pythia-410m-deduped'
prompt = "Get the milk they said"

tokenizer = AutoTokenizer.from_pretrained(md)
model = AutoModelForCausalLM.from_pretrained(md)
# lets run beam search using 5 beams
num_beams = 5

class CustomProcessor(LogitsProcessor):
  def __call__(self, input_ids, scores):
    for beam_idx in range(scores.shape[0]): # dont invalidate all beams in do_sample=False or your kernel will crash OOM
        scores[beam_idx, :] = -float("inf")  # Set all scores to negative infinity to stop further generation
    print(scores)
    return scores

logits_processor = LogitsProcessorList([CustomProcessor()])
inp = tokenizer(prompt, return_tensors="pt")
input_ids = inp.input_ids

beam_output = model.generate(
    input_ids,
    max_length=50, 
    num_beams=num_beams, 
    do_sample=True,
    temperature=0.89,
    top_k=30,
    early_stopping=True,
    attention_mask=inp.attention_mask,
    remove_invalid_values=True,
    logits_processor=logits_processor,
    pad_token_id=41420,
    eos_token_id=tokenizer.eos_token_id,
)
print('::::::::::::::::::::::')
tokenizer.batch_decode(beam_output, skip_special_tokens=False)[0]