What model-pairs are supported by the assistant decoding generation in Huggingface AutoModelForCausalLM?

The assisted decoding approach described in the blog post "Assisted Generation: a new direction toward low-latency text generation" is implemented in "Generate: Add assisted generation" (huggingface/transformers Pull Request #22211 by gante).

Q Part 1. Which model pairings are known to work with the model.generate(..., assistant_model=...) feature?

Q Part 2. Does it work for decoder-only models too? Has anyone tried any pairs of decoder-only models available on the Hugging Face Hub?


The assumptions for assisted decoding are:

  • the tokenizer must be the same for the assistant and the main model
  • both models are supported by AutoModelForCausalLM
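The first assumption matters because, in the implementation from the PR above, the assistant's candidate token IDs are handed directly to the main model, so the same token string has to map to the same ID in both. A minimal sketch of a pre-check, assuming both tokenizers expose get_vocab() (which returns a token-to-ID dict); vocab_mismatches is a hypothetical helper, not a transformers function:

```python
from typing import Dict, List


def vocab_mismatches(main_vocab: Dict[str, int], assistant_vocab: Dict[str, int]) -> List[str]:
    """Hypothetical helper: list tokens whose IDs differ between two vocabularies.

    A token present in only one vocabulary also counts as a mismatch,
    because .get() returns None for the missing side.
    """
    mismatches = []
    for token in sorted(main_vocab.keys() | assistant_vocab.keys()):
        if main_vocab.get(token) != assistant_vocab.get(token):
            mismatches.append(token)
    return mismatches
```

With real checkpoints this would be called as vocab_mismatches(AutoTokenizer.from_pretrained(main_name).get_vocab(), AutoTokenizer.from_pretrained(assistant_name).get_vocab()). An empty result is necessary but not sufficient: it compares only the token-to-ID tables, not normalization or pre-tokenization rules.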

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Main model and a smaller assistant from the same family (shared tokenizer)
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant = "EleutherAI/pythia-160m-deduped"

# For the BERT pair below, bos_token_id=101, eos_token_id=102 were also passed here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

assistant_model = AutoModelForCausalLM.from_pretrained(assistant)

tokenized_inputs = tokenizer("Alice and Bob", return_tensors="pt")

# Passing assistant_model makes generate() use assisted decoding
outputs = model.generate(**tokenized_inputs, assistant_model=assistant_model)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
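To see why the pair must share token IDs, here is a toy sketch of greedy assisted decoding as described in the blog post: the assistant drafts a few tokens, the main model checks them, and the longest agreeing prefix is kept plus one token from the main model. This is an illustration under simplifying assumptions (a "model" is a plain function returning its greedy next token ID, no EOS handling, no KV cache), not the transformers implementation; assisted_greedy_decode and num_candidates are made-up names.

```python
from typing import Callable, List

# Hypothetical stand-in for a model: given the token IDs so far, return the
# token ID it would pick next under greedy decoding.
NextToken = Callable[[List[int]], int]


def assisted_greedy_decode(
    main: NextToken,
    assistant: NextToken,
    prompt: List[int],
    max_new_tokens: int,
    num_candidates: int = 5,
) -> List[int]:
    """Toy greedy assisted decoding; not the transformers implementation."""
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1) The cheap assistant drafts up to num_candidates tokens greedily.
        draft: List[int] = []
        for _ in range(min(num_candidates, target_len - len(tokens))):
            draft.append(assistant(tokens + draft))

        # 2) The main model checks the draft position by position. A real
        #    model scores all drafted positions in a single forward pass.
        accepted: List[int] = []
        for candidate in draft:
            expected = main(tokens + accepted)
            accepted.append(expected)
            if expected != candidate:
                # 3) First disagreement: keep the main model's token, drop the rest.
                break
        else:
            # Whole draft accepted: the same pass also yields one more token.
            if len(tokens) + len(accepted) < target_len:
                accepted.append(main(tokens + accepted))

        tokens.extend(accepted)
    return tokens


if __name__ == "__main__":
    def main_model(seq: List[int]) -> int:
        return (seq[-1] + 3) % 10

    def draft_model(seq: List[int]) -> int:
        # Agrees with main_model except right after a 7.
        return 5 if seq[-1] == 7 else (seq[-1] + 3) % 10

    print(assisted_greedy_decode(main_model, draft_model, [1], max_new_tokens=8))
```

Every kept token is one the main model would have chosen itself, so the output matches plain greedy decoding with the main model; the assistant only changes how many main-model passes are needed. That guarantee relies on the drafted IDs meaning the same thing to both models.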

I've tried the following and this works:

  • EleutherAI/pythia-1.4b-deduped + EleutherAI/pythia-160m-deduped

But these didn't:

  • google-bert/bert-large-uncased + google-bert/bert-base-uncased (I also had to add bos_token_id=101, eos_token_id=102 to the model and/or tokenizer initialization to avoid a None token ID when the assistant model scopes down the vocabulary)
  • FacebookAI/xlm-roberta-large + FacebookAI/xlm-roberta-base (ended up with TypeError: object of type 'NoneType' has no len() during candidate generation)
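One plausible reason for both failures is that BERT and XLM-RoBERTa are encoder-only checkpoints pre-trained with masked language modeling rather than left-to-right generation: AutoModelForCausalLM can load them, but they were not trained to produce the next token. A rough screen before trying a pair is to look at each checkpoint's config.architectures. The suffix rule below is a hypothetical heuristic, not a check that transformers performs, and it can misfire (some older seq2seq class names also end in "LMHeadModel"):

```python
from typing import Optional, Sequence

# Hypothetical heuristic: class-name suffixes that usually mark checkpoints
# trained as left-to-right (causal) language models.
CAUSAL_SUFFIXES = ("ForCausalLM", "LMHeadModel")


def looks_causal(architectures: Optional[Sequence[str]]) -> bool:
    """Guess from a config's architectures list whether a checkpoint is a causal LM."""
    if not architectures:
        return False  # field missing or empty: treat as unknown, not causal
    return any(name.endswith(CAUSAL_SUFFIXES) for name in architectures)


def pair_looks_usable(
    main_architectures: Optional[Sequence[str]],
    assistant_architectures: Optional[Sequence[str]],
) -> bool:
    """Both checkpoints should look causal before trying assisted decoding."""
    return looks_causal(main_architectures) and looks_causal(assistant_architectures)
```

With real checkpoints the lists would come from AutoConfig.from_pretrained(name).architectures; a masked-LM checkpoint typically reports a name like BertForMaskedLM there, which this check rejects. Passing the check says nothing about whether the two tokenizers match.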

Also asked on Stack Overflow ("What model-pairs are supported by the assistant decoding generation in Huggingface AutoModelForCausalLM?") and on the PR itself ("Generate: Add assisted generation", huggingface/transformers Pull Request #22211).