I would like to call Hugging Face's GenerationMixin.generate() method, model.generate, with a model such as Llama 2, where on every pass the size of input_ids does not change.
For instance, if I had an input of 3 tokens and wanted to generate up to max_length=7, then:
Before any generation: [8 5 1 0 0 0 0]
After 1 pass: [8 5 1 3 0 0 0]
After 2 passes: [8 5 1 3 7 0 0]
…etc
is the behavior I am looking for. For some more context, I am compiling the model in a setting where dynamic input shapes are not supported, so the input size needs to remain the same on every pass (i.e., for every output token generated) in model.generate().
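Concretely, the per-step update I have in mind looks like the toy sketch below (made-up token IDs, no real model or tokenizer involved, just to illustrate the fixed-shape behavior):
import torch

max_length = 7
input_ids = torch.tensor([[8, 5, 1, 0, 0, 0, 0]])  # padded to max_length up front
cur_len = 3  # number of real (non-pad) tokens so far

# One decoding step: write the newly generated token into the first padded slot
# instead of concatenating, so input_ids.shape stays (1, max_length).
next_token = 3  # pretend the model predicted token id 3
input_ids[0, cur_len] = next_token
cur_len += 1
print(input_ids)  # tensor([[8, 5, 1, 3, 0, 0, 0]])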
Here is a script I have that works, but uses a custom loop to replace the padded token after each generated token. In this case, the input has been padded to max_length initially.
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaConfig,
)
max_seq_length = 20
model_name = "Maykeye/TinyLLama-v0"
model_config = LlamaConfig.from_pretrained(model_name)
model_config.use_cache = False
model_config.pad_token_id = model_config.eos_token_id
model_config.max_length = max_seq_length
model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_to_max_length = True
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenize = lambda x: tokenizer(
    x,
    return_tensors="pt",
    padding="max_length",
    max_length=max_seq_length,
    truncation=True,
)
compiled_model = torch.compile(model)
text = "My cat is"
text_len = len(text.split())
for _ in range(max_seq_length - text_len):
    encoded_input = tokenize(text)
    result = compiled_model(**encoded_input)["logits"]
    # Greedily pick the next token from the logits at the last real position
    logits = result[0, text_len, :]
    softmax = torch.nn.Softmax(dim=-1)(logits)
    argmax = torch.argmax(softmax, dim=-1)
    word = tokenizer.decode(argmax)
    text_len += 1
    text += " " + word
    print(text)
What I want to run instead is (with input padded to max length):
generated_ids = compiled_model.generate(**encoded_input)
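Spelled out with explicit arguments (all standard generate() keyword arguments, and tokenize() as defined in the script above), what I am hoping for is roughly:
encoded_input = tokenize(text)  # input_ids already padded to max_seq_length
generated_ids = compiled_model.generate(
    input_ids=encoded_input["input_ids"],
    attention_mask=encoded_input["attention_mask"],
    max_length=max_seq_length,
)
# Desired: generated_ids keeps the same (1, max_seq_length) shape on every step,
# with pad positions progressively overwritten by generated tokens.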
After looking through generate() in GenerationMixin, I notice that in generate_no_beam_search (and the beam search variant), generate appends to input_ids, with no option to replace in place as far as I can tell. At line 557 in generate:
# add token and increase length by one
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
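By contrast, the fixed-shape update I need would look more like the sketch below, where cur_len is a hypothetical counter for the number of real (non-pad) tokens and not a variable that exists in generate() today:
# Hypothetical in-place alternative: overwrite the first padded slot so that
# input_ids keeps a static shape instead of growing by one each step.
input_ids[:, cur_len] = tokens_to_add
cur_len += 1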
I wanted to ask:
- Is it intentional that generate() doesn't support replacing padded tokens in-place? If so, why? Is this use case uncommon?
- If not intentional, will this functionality be supported in the future? When can we expect it?
Mainly, I would like to be able to use generate() for the example shown above, so any suggestions would be greatly appreciated. The solution would need to work without changing the source code of generate() directly or re-implementing parts of generate() in my own library.