(This comment might be superfluous, but a simple “like” didn’t express it well enough, therealadrian.)
I just wanted to thank you for the pointer; this same question had bugged me for a good while, and I couldn’t understand how that was possible. I thought the total runtime of decoder generation would scale as O(n^3), but I was wrong: it is “only” O(n^2), where n is simply the combined length of prompt and output. (With the cached keys and values, each new token only attends to the at most n tokens before it, so generation is n steps of O(n) work each.) In practice the output is likely slower to compute, since it can’t be parallelized as well on a GPU; that’s probably why OpenAI charges more for output than for input, but “just” by a factor of 2.
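A rough way to see this empirically is to time generation with and without the key/value cache. This is just a sketch, assuming a CUDA GPU and the Hugging Face transformers library; the prompt and the token count are arbitrary, and use_cache is generate’s flag for toggling the cache:

import time
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda').eval()

prompt = tokenizer('Hello world,', return_tensors='pt').to('cuda')

for use_cache in (True, False):
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        # use_cache=True keeps the key/value vectors of previous tokens around,
        # so each new token is processed once; use_cache=False recomputes the
        # whole prefix at every step.
        model.generate(**prompt, max_new_tokens=200, use_cache=use_cache,
                       do_sample=False, pad_token_id=tokenizer.eos_token_id)
    torch.cuda.synchronize()
    print(f'use_cache={use_cache}: {time.time() - start:.2f}s')

The cached run should be noticeably faster, and the gap grows with the number of generated tokens.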
For reference, the same holds for all similar models, like GPT2, BLOOM, etc. My early experiments with BLOOM puzzled me because runtime hardly seemed to depend on (short) prompt length at all; now I understand why.
The key part of that blog post, and the reason this works, is that GPT2-like models only use masked attention, even in training. That is, during training, if the model sees the sentence “Hello world, how are you?”, then the hidden state (and hence the key, value, and query vectors) of each token is computed by looking only at that token and the previous ones. This of course makes sense, because you want to use each output to predict the next token (and apply the loss there), so you can’t cheat by looking at the future; but a priori it’s not a given. You could, for example, use full attention over all the earlier words when predicting the final “?”, as long as you only applied the loss to that token. In that case the key, query, and value vectors of each token would depend on the whole sentence, and that would still be a valid model. In practice, though, it would be a very bad idea, because training would be much slower (instead of learning all next-token predictions at once, you’d learn only one per sentence). I guess that’s how it worked in LSTM times, and why transformers allowed massively more parallelized training. Plus, as noted, inference would be much slower.
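Here is a tiny toy sketch of that masking idea: a single hand-rolled attention head with a lower-triangular mask, random weights, and made-up dimensions (not GPT-2’s actual implementation). Changing the last token leaves the attention outputs of all earlier positions untouched:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(seq_len, d)                   # token embeddings for one sentence

# toy single-head attention with a causal (masked) pattern
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv
scores = q @ k.T / d ** 0.5
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()   # lower-triangular mask
out = F.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1) @ v

# replace the last token and recompute everything
x2 = x.clone()
x2[-1] = torch.randn(d)
q2, k2, v2 = x2 @ Wq, x2 @ Wk, x2 @ Wv
scores2 = (q2 @ k2.T / d ** 0.5).masked_fill(~mask, float('-inf'))
out2 = F.softmax(scores2, dim=-1) @ v2

print(torch.isclose(out, out2).all(dim=1))    # True for every position except the last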
To check experimentally that this is indeed the case, one can verify that the per-token outputs (as well as the intermediate layers) are identical at every position except the last when the model is fed two sentences that differ only in their last token:
import numpy as np
import torch
from transformers import AutoTokenizer, GPT2Model

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2').to('cuda').eval()   # move the model to the GPU as well

outputs = []
for text in ['This is an awesome prompt', 'This is an awesome feature']:
    encoded_input = tokenizer(text, return_tensors='pt')
    cuda_input = {k: v.to('cuda') for k, v in encoded_input.items()}
    with torch.no_grad():   # no gradients needed, and .numpy() below requires detached tensors
        outputs.append(model(**cuda_input))

# compare the final hidden states token by token
print(np.isclose(outputs[0].last_hidden_state[0].cpu().numpy(),
                 outputs[1].last_hidden_state[0].cpu().numpy()).all(axis=1))
This returns:
[ True True True True False]
i.e. identical hidden states everywhere except for the very last token.
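The same check can be extended to the intermediate layers mentioned above by asking the model for all hidden states. A small extension of the snippet, reusing the same tokenizer and model (the variable name full_outputs is just for illustration):

# re-run with output_hidden_states=True to also compare every layer
full_outputs = []
for text in ['This is an awesome prompt', 'This is an awesome feature']:
    encoded_input = tokenizer(text, return_tensors='pt').to('cuda')
    with torch.no_grad():
        full_outputs.append(model(**encoded_input, output_hidden_states=True))

for layer, (h0, h1) in enumerate(zip(full_outputs[0].hidden_states,
                                     full_outputs[1].hidden_states)):
    same = np.isclose(h0[0].cpu().numpy(), h1[0].cpu().numpy()).all(axis=1)
    print(layer, same)   # every layer: True everywhere except the last token

Each layer (including the embedding layer at index 0) should print the same pattern as above: identical activations at all positions except the last.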