This post is related to Whitelist specific tokens in beam search - #4 by 100worte
I see that methods such as `beam_search()` and `sample()` have a `logits_processor` parameter, but `generate()` does not. As of 4.12.3, `generate()` seems to call `_get_logits_processor()` with no way to pass additional logits processors. My understanding is that we are supposed to call `generate()` with parameters rather than calling the other generation methods directly. Is that correct? If so, how should I fix the minimal working example below so that I can use `MyCustomLogitsProcessor`? Thank you!
```python
import torch
import transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.generation_logits_process import LogitsProcessor, LogitsProcessorList


class MyCustomLogitsProcessor(LogitsProcessor):
    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        return scores  # Minimally working: pass the scores through unchanged


if __name__ == "__main__":
    print(transformers.__version__)  # 4.12.3
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    logits_processor_list = LogitsProcessorList([
        MyCustomLogitsProcessor(),
    ])
    generation_output = model.generate(
        **inputs,
        return_dict_in_generate=True,
        output_scores=True,
        logits_processor=logits_processor_list,  # Breaks: got multiple values for keyword argument 'logits_processor'
        # What should I do here to use MyCustomLogitsProcessor?
    )
    # Expected: Hello, my dog is cute and icky. I'm not sure if she's a good dog
    print(tokenizer.decode(generation_output["sequences"][0], skip_special_tokens=True))
```
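For reference, the only hook I have found so far is the one below. This is just a sketch against 4.12.3's private API (the subclass name `GPT2WithCustomProcessor` is my own, and `_get_logits_processor()` is an internal method whose signature may change between releases): subclass the model, append the custom processor to the list `generate()` builds internally, and call `generate()` without the `logits_processor` argument. Is this the intended approach, or is there a supported way?

```python
# Continuing the example above: a sketch of a possible workaround for 4.12.3.
# _get_logits_processor() is private, so this may break on other versions.
class GPT2WithCustomProcessor(GPT2LMHeadModel):
    def _get_logits_processor(self, *args, **kwargs) -> LogitsProcessorList:
        # Forward everything to the parent so we don't depend on the private
        # method's exact signature, then append the custom processor.
        processors = super()._get_logits_processor(*args, **kwargs)
        processors.append(MyCustomLogitsProcessor())
        return processors


model = GPT2WithCustomProcessor.from_pretrained("gpt2")
generation_output = model.generate(
    **inputs,  # no logits_processor kwarg here, so the keyword collision is avoided
    return_dict_in_generate=True,
    output_scores=True,
)
print(tokenizer.decode(generation_output["sequences"][0], skip_special_tokens=True))
```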