Avoid recalculating hidden states between generate calls?

I have a 100-token prompt that I use to generate some text; this takes X milliseconds.
My next generate() call uses the same 100 tokens as the previous call plus 3 more tokens.

To speed up the response time, I want to avoid recalculating the hidden states for the first 100 tokens and only process the 3 new tokens.
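At the level of a single forward pass, reusing the cache looks roughly like the sketch below. It assumes PyTorch and a reasonably recent `transformers`, and uses a tiny randomly initialised GPT-2 (the `GPT2Config` sizes are arbitrary, chosen only so nothing has to be downloaded); random ids stand in for the real 100-token prompt and the 3 extra tokens. The second call receives only the 3 new ids plus `past_key_values`, yet should produce the same logits as recomputing all 103 tokens.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)
# Tiny random GPT-2 (arbitrary sizes) so the sketch runs offline.
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()

prompt_ids = torch.randint(0, config.vocab_size, (1, 100))  # stand-in for the 100-token prompt
added_ids = torch.randint(0, config.vocab_size, (1, 3))     # stand-in for the 3 new tokens

with torch.no_grad():
    # First call: run the 100-token prompt once and keep its key/value cache.
    first = model(prompt_ids, use_cache=True)
    cache = first.past_key_values

    # Second call: feed ONLY the 3 new tokens, reusing the cached keys/values.
    second = model(added_ids, past_key_values=cache, use_cache=True)

    # Reference: recompute everything from scratch on all 103 tokens.
    full = model(torch.cat([prompt_ids, added_ids], dim=-1))

# Logits for the 3 new positions should agree up to floating-point noise.
match = torch.allclose(second.logits, full.logits[:, -3:], atol=1e-5)
print(match)
```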

How can I achieve that? I couldn't find how to do this in the docs.

Thank you

You will basically need to smuggle the model's cached state between generate() invocations. Strictly speaking, what gets cached is each layer's attention keys and values rather than the hidden states; the code calls this past_key_values. I ended up having to take over the generation loop and slightly patch Transformers to do this. In very rough Python, your modification to the generation loop will look something like this:

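```python
import torch

def generate(model, old_prompt, added, past_state):
    # old_prompt: LongTensor of token ids, shape (1, seq_len), already covered
    #   by past_state['past_key_values'] from the previous call
    # added: 1-D LongTensor holding the token ids appended since that call
    ...
    # catch up: decode the new tokens one at a time, extending the cache
    prompt = old_prompt
    for new_token in added:
        # token ids are tensors, so append with torch.cat rather than +=
        prompt = torch.cat([prompt, new_token.view(1, 1)], dim=-1)
        # model-specific helper that builds the forward() kwargs from ids + cache
        model_inputs = model.prepare_inputs_for_generation(prompt, **past_state)
        output = model(**model_inputs, return_dict=True)
        past_state['past_key_values'] = output['past_key_values']
    # resume the normal sampling loop from here once the cache has caught up
    while True:
        model_inputs = model.prepare_inputs_for_generation(prompt, **past_state)
        output = model(**model_inputs, return_dict=True)
        ...
    ...
```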

I basically copied sample() from transformers/generation/utils.py and modified it as described above. I know it's far from complete code, but it should give you an idea of what is required. There may well be an easier way to do this, but the Transformers generate() function didn't seem to allow passing in the initial cache (otherwise I wouldn't have had to write a patch like this). This really should be possible out of the box: imagine how many cycles are being wasted in experiments that use Transformers for incremental text generation.


Thank you. The main value of this would be an improvement in the time-per-token metric.
The catch is that any change to the generate function might be suboptimal and end up hurting that metric :slight_smile:
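To check whether a change actually helps time per token, one option is to time the same greedy decode with and without the cache. A rough sketch, assuming PyTorch and `transformers`, with a tiny random GPT-2 as a stand-in model; `ms_per_token` is a hypothetical helper, not a library function, and absolute numbers on a toy model mean little, so run it against your real model and hardware:

```python
import time

import torch
from transformers import GPT2Config, GPT2LMHeadModel


def ms_per_token(model, ids, n_new, use_cache):
    """Hypothetical helper: greedy-decode n_new tokens and return
    (average milliseconds per generated token, final token ids)."""
    past = None
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_new):
            if past is not None:
                # Cached path: feed only the newest token plus the cache.
                out = model(ids[:, -1:], past_key_values=past, use_cache=True)
            else:
                # First step, or caching disabled: run the whole sequence.
                out = model(ids, use_cache=use_cache)
            past = out.past_key_values if use_cache else None
            next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=-1)
    return (time.perf_counter() - start) * 1000 / n_new, ids


torch.manual_seed(0)
# Tiny random GPT-2 (arbitrary sizes) as a stand-in for a real model.
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()
prompt_ids = torch.randint(0, config.vocab_size, (1, 100))

t_cached, ids_cached = ms_per_token(model, prompt_ids, n_new=20, use_cache=True)
t_full, ids_full = ms_per_token(model, prompt_ids, n_new=20, use_cache=False)
print(f"with cache: {t_cached:.2f} ms/token, without: {t_full:.2f} ms/token")
```

Both runs should produce the same tokens; the difference is that the uncached path reprocesses the whole sequence at every step, while the cached path processes one new token per step.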

I'll post an update once I've tested it and seen whether it improves generation times.

See FasterTransformer/gpt_guide.md (NVIDIA/FasterTransformer on GitHub, main branch); they have a solution for that.