Use custom LogitsProcessor in `model.generate()`

This post is related to the thread "Whitelist specific tokens in beam search" (post #4 by 100worte).

I see that methods such as beam_search() and sample() have a logits_processor parameter, but generate() does not. As of transformers 4.12.3, generate() appears to build its processor list by calling _get_logits_processor(), with no way to pass in additional logits processors.
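For context, a logits processor is just a callable that receives the token ids generated so far plus the next-token scores and returns (possibly modified) scores; LogitsProcessorList calls its members one after another in list order. The sketch below exercises that contract on a dummy score tensor, so no model is needed. TokenBiasLogitsProcessor is a hypothetical example class, not part of transformers; TemperatureLogitsWarper is the library's built-in temperature scaler.

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList, TemperatureLogitsWarper


class TokenBiasLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: adds a fixed bias to the scores of selected token ids."""

    def __init__(self, bias_by_token):
        self.bias_by_token = bias_by_token  # mapping {token_id: bias}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = scores.clone()  # avoid mutating the caller's tensor
        for token_id, bias in self.bias_by_token.items():
            scores[:, token_id] += bias
        return scores


# Processors run in list order: first the bias is added, then scores are divided by 2.0.
processors = LogitsProcessorList([
    TokenBiasLogitsProcessor({2: 4.0}),
    TemperatureLogitsWarper(2.0),
])
input_ids = torch.tensor([[0, 1]])  # tokens generated so far (batch of 1)
scores = torch.zeros(1, 5)          # next-token scores over a toy 5-token vocabulary
new_scores = processors(input_ids, scores)
print(new_scores)  # tensor([[0., 0., 2., 0., 0.]])
```

Because each processor only sees (input_ids, scores), a custom one like MyCustomLogitsProcessor can be unit-tested this way before it is plugged into any decoding loop.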

My understanding is that we are supposed to call generate() with the appropriate parameters rather than calling any of the other generation methods directly. Is that correct? How should I fix the minimal working example below so that it uses MyCustomLogitsProcessor? Thank you!

import transformers
import torch
from transformers.generation_logits_process import LogitsProcessor, LogitsProcessorList
from transformers import GPT2Tokenizer, GPT2LMHeadModel

class MyCustomLogitsProcessor(LogitsProcessor):
    def __init__(self):
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        return scores # Minimally working

if __name__ == '__main__':
    print(transformers.__version__)  # 4.12.3
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    logits_processor_list = LogitsProcessorList([
        MyCustomLogitsProcessor(),
    ])
    generation_output = model.generate(**inputs,
                                       return_dict_in_generate=True,
                                       output_scores=True,
                                       logits_processor=logits_processor_list,  # Breaks: TypeError, got multiple values for keyword argument 'logits_processor'
                                       # What should I do here to use MyCustomLogitsProcessor?
                                       )

    # Expected: Hello, my dog is cute and icky. I'm not sure if she's a good dog
    print(tokenizer.decode(generation_output["sequences"][0], skip_special_tokens=True))

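As it turns out, in this version you cannot pass a custom logits processor list to the model.generate(...) call. Instead, you have to build the logits processors (and stopping criteria) yourself and call the underlying decoding method directly, e.g. beam_search() with your own beam scorer, or greedy_search() as in this piece of code I had lying around from a research project: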

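import torch
from transformers import (
    LogitsProcessorList,
    MaxLengthCriteria,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    StoppingCriteriaList,
)


# Hypothetical wrapper signature: the original snippet was the body of a larger function.
def generate_text(model, tokenizer, ids, max_length, beams, bad_words_ids,
                  repetition_penalty, horizon_penalty, extra_bad_words=None, horizon=None):
    # Copy first so that += does not mutate the caller's bad_words_ids list
    bad_words_t = list(bad_words_ids)
    if extra_bad_words is not None:
        bad_words_t += extra_bad_words

    if horizon is None:
        # Only built-in processors needed: generate() creates them from these kwargs
        model_out = model.generate(input_ids=ids["input_ids"],
                                   max_length=max_length, num_beams=beams,
                                   no_repeat_ngram_size=5, bad_words_ids=bad_words_t,
                                   repetition_penalty=repetition_penalty)[0]
    else:
        horizon_ids = tokenizer(horizon, return_tensors="pt")["input_ids"].to(model.device)
        model.config.max_length = max_length
        # Instantiate the logits processors by hand, including the custom one
        logits_processor = LogitsProcessorList([
            MinLengthLogitsProcessor(ids["input_ids"].shape[1], model.config.eos_token_id),
            NoRepeatNGramLogitsProcessor(5),
            NoBadWordsLogitsProcessor(bad_words_t, eos_token_id=model.config.eos_token_id),
            # Custom processor from the research project; not part of transformers
            HorizonRepetitionPenalty(penalty=horizon_penalty, horizon=horizon_ids, horizon_exclusive=True),
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
        ])
        stopping_criteria = StoppingCriteriaList([
            MaxLengthCriteria(max_length=max_length),
        ])
        model_kwargs = {
            "attention_mask": ids["attention_mask"],
            "use_cache": True,
        }
        with torch.no_grad():
            # greedy_search() (public in 4.12.x) accepts logits_processor directly;
            # model_kwargs must be forwarded or the attention mask is ignored
            model_out = model.greedy_search(
                input_ids=ids["input_ids"], logits_processor=logits_processor,
                stopping_criteria=stopping_criteria, **model_kwargs)[0]

    return tokenizer.decode(model_out)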
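If greedy_search() is not available as a public method in your transformers version, the same idea can be sketched as a plain greedy loop that runs the processors itself, making the same (input_ids, scores) call on each step's next-token logits. This is a minimal sketch under a few assumptions: greedy_with_processors is a hypothetical helper, it skips the KV cache (the whole sequence is re-encoded every step), and the tiny randomly initialised GPT-2 exists only so the example runs without downloading weights.

```python
import torch
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
)


@torch.no_grad()
def greedy_with_processors(model, input_ids, logits_processor, max_new_tokens):
    """Hypothetical helper: greedy decoding that applies `logits_processor` at every step."""
    for _ in range(max_new_tokens):
        # Scores for the next position only: shape (batch, vocab_size)
        next_token_logits = model(input_ids=input_ids).logits[:, -1, :]
        # Each processor receives (tokens so far, scores) and returns new scores
        scores = logits_processor(input_ids, next_token_logits)
        next_tokens = torch.argmax(scores, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    return input_ids


# Tiny random GPT-2 so the sketch runs offline; use
# GPT2LMHeadModel.from_pretrained("gpt2") for real text.
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()

processors = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(2),
    RepetitionPenaltyLogitsProcessor(penalty=1.2),
])
prompt = torch.tensor([[1, 2, 3, 4]])
output = greedy_with_processors(model, prompt, processors, max_new_tokens=10)
```

A custom class such as MyCustomLogitsProcessor can be appended to the same list; the loop does not care whether a processor is built-in or user-defined.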

I think if the devs are willing to add the ability to pass a custom logits processor to the generate function, it would be a great addition.
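For readers on newer transformers releases: generate() there does accept a logits_processor argument, and the custom processors are combined with the ones generate() builds from its other arguments. A minimal sketch, assuming such a version; AllowOnlyTokensLogitsProcessor is a hypothetical whitelist processor in the spirit of the linked thread, and the tiny random GPT-2 (with eos_token_id=0 deliberately outside the whitelist) is only there so the example runs offline.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, LogitsProcessor, LogitsProcessorList


class AllowOnlyTokensLogitsProcessor(LogitsProcessor):
    """Hypothetical whitelist processor: every token id outside `allowed_ids` gets -inf."""

    def __init__(self, allowed_ids):
        self.allowed_ids = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Block everything, then re-open the allowed columns
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_ids.to(scores.device)] = 0.0
        return scores + mask


# EOS (id 0) is not whitelisted, so exactly max_new_tokens tokens are generated.
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2,
                    bos_token_id=0, eos_token_id=0)
model = GPT2LMHeadModel(config).eval()

allowed = [5, 7, 11]
input_ids = torch.tensor([[1, 2, 3]])
output = model.generate(
    input_ids,
    attention_mask=torch.ones_like(input_ids),
    logits_processor=LogitsProcessorList([AllowOnlyTokensLogitsProcessor(allowed)]),
    max_new_tokens=6,
    do_sample=False,
    pad_token_id=0,
)
```

With a real checkpoint the call is the same: load it with from_pretrained and pass the tokenizer's input_ids and attention_mask.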
