Dynamically resizing input for Huggingface's generate()

I would like to call the generate() method of Hugging Face's GenerationMixin (i.e. model.generate) with a model such as Llama 2, in such a way that the size of input_ids stays the same on every pass.

For instance, if I had an input of 3 tokens and wanted to generate up to max_length=7, then:

Before any generation: [8 5 1 0 0 0 0]
After 1 pass: [8 5 1 3 0 0 0]
After 2 passes: [8 5 1 3 7 0 0]
…etc

is the behavior I am looking for. For some more context, I am compiling the model in a setting where dynamic input shapes are not supported, so the input size needs to stay the same on every pass (i.e. for every output token generated) inside model.generate().
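Concretely, the update rule I have in mind looks something like the sketch below (plain PyTorch, with the token ids taken from the toy example above and next_token standing in for whatever the model actually produces):

import torch

max_length = 7
input_ids = torch.tensor([8, 5, 1, 0, 0, 0, 0])  # padded to max_length up front
num_filled = 3                                   # number of real tokens so far

for next_token in [3, 7]:               # stand-ins for the model's outputs
    input_ids[num_filled] = next_token  # overwrite a pad slot in-place
    num_filled += 1                     # the shape of input_ids never changes

print(input_ids)  # tensor([8, 5, 1, 3, 7, 0, 0])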

Here is a script I have that works, but it uses a custom loop that overwrites the next padded position after each generated token. The input is padded to max_length up front.

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaConfig,
)

max_seq_length = 20
model_name = "Maykeye/TinyLLama-v0"

# Configure the model for fixed-length, cache-free generation.
model_config = LlamaConfig.from_pretrained(model_name)
model_config.use_cache = False
model_config.pad_token_id = model_config.eos_token_id
model_config.max_length = max_seq_length

model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
model.eval()

# Right-pad every input to max_seq_length so the model always sees the same shape.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

def tokenize(prompt):
    return tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_seq_length,
        truncation=True,
    )

compiled_model = torch.compile(model)

text = "My cat is"
text_len = len(text.split())

for _ in range(max_seq_length - text_len):
    encoded_input = tokenize(text)
    result = compiled_model(**encoded_input)["logits"]

    # Greedily pick the next token from the logits at the current end of the
    # prompt, then append the decoded word to the running text.
    logits = result[0, text_len, :]
    next_token_id = torch.argmax(logits, dim=-1)
    word = tokenizer.decode(next_token_id)
    text_len += 1
    text += " " + word

print(text)

What I want to run instead is (with input padded to max length):

generated_ids = compiled_model.generate(**encoded_input)

After looking through generate() in GenerationMixin, I noticed that the decoding loops (generate_no_beam_search, and likewise beam search) always append to input_ids, with no option, as far as I can tell, to replace a padded position instead. For example, at line 557 in generate:

# add token and increase length by one
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
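The shape-preserving variant I am after would instead write into the pre-padded tensor, roughly like this (a hypothetical sketch, not actual library code; it assumes input_ids was padded to max_length up front and that cur_len tracks the number of real tokens):

# hypothetical: overwrite the next pad position instead of growing the tensor
input_ids[:, cur_len] = tokens_to_add
cur_len = cur_len + 1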

I wanted to ask:

  1. Is it intentional that generate() doesn't support replacing padded tokens in-place? If so, why? Is this use case uncommon?
  2. If it isn't intentional, will this functionality be supported in the future, and when can we expect it?

Mainly, I would like to be able to use generate() with the example I showed above, so any suggestions would be greatly appreciated. The solution would need to be one that does not require changing the source code of generate() directly, re-writing parts of generate() in my own library, and so on.
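To make that constraint concrete: something along the lines of the hypothetical shim below, which only wraps the model's forward and leaves generate() untouched, would be acceptable, while copying generate()'s decoding loop into my own code would not. This is only a sketch of the kind of workaround I could live with; the names and the padding logic are illustrative, not tested.

import torch.nn.functional as F

original_forward = model.forward

def padded_forward(input_ids, attention_mask=None, **kwargs):
    # Pad the (growing) inputs coming from generate() out to max_seq_length so
    # the compiled graph always sees one shape, then slice the logits back.
    cur_len = input_ids.shape[-1]
    pad_len = max_seq_length - cur_len
    if pad_len > 0:
        input_ids = F.pad(input_ids, (0, pad_len), value=model.config.pad_token_id)
        if attention_mask is not None:
            attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
    outputs = original_forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
    outputs.logits = outputs.logits[:, :cur_len, :]
    return outputs

model.forward = padded_forward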