When using a custom LogitsProcessor
I am able to invaldiate a beam by setting it’s score to
scores[beam_idx, :] = -float("inf") # Set all scores to negative infinity to stop further generation
This works fine when using do_sample=False
. However, whenever I do this with sampling enabled, I get a RuntimeError: probability tensor contains either
inf,
nan or element < 0
.
How do I invalidate all tokens of a beam with do_sample=True
?
If I set this to
scores[beam_idx, :] = -float(1e19) # go down very low but not -inf without overflowing the variable
I can see the effect, though random tokens are produced.
Edit: Minimal Example
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LogitsProcessorList,
LogitsProcessor,
MinLengthLogitsProcessor,
TopKLogitsWarper,
TemperatureLogitsWarper,
BeamSearchScorer,
)
import torch
md = 'EleutherAI/pythia-410m-deduped'
prompt = "Get the milk they said"
tokenizer = AutoTokenizer.from_pretrained(md)
model = AutoModelForCausalLM.from_pretrained(md)
# lets run beam search using 5 beams
num_beams = 5
class CustomProcessor(LogitsProcessor):
def __call__(self, input_ids, scores):
for beam_idx in range(scores.shape[0]): # dont invalidate all beams in do_sample=False or your kernel will crash OOM
scores[beam_idx, :] = -float("inf") # Set all scores to negative infinity to stop further generation
print(scores)
return scores
logits_processor = LogitsProcessorList([CustomProcessor()])
inp = tokenizer(prompt, return_tensors="pt")
input_ids = inp.input_ids
beam_output = model.generate(
input_ids,
max_length=50,
num_beams=num_beams,
do_sample=True,
temperature=0.89,
top_k=30,
early_stopping=True,
attention_mask=inp.attention_mask,
remove_invalid_values=True,
logits_processor=logits_processor,
pad_token_id=41420,
eos_token_id=tokenizer.eos_token_id,
)
print('::::::::::::::::::::::')
tokenizer.batch_decode(beam_output, skip_special_tokens=False)[0]