Hi,
I was curious to know how the number of forward passes scales with the number of beams in the generate method in transformers. My understanding is that even greedy generation requires at least max_seq_length forward passes, assuming we predict one token at a time. With k beams, predicting the first token takes one forward pass, and predicting the second token takes k forward passes (one per beam). Selecting the top k tokens from each of those forward passes gives k^2 candidates, so predicting the third token takes k^2 forward passes. The total number of forward passes is then the sum of a geometric series, 1 + k + k^2 + k^3 + ... (n terms), and hence O(k^n). Could anyone shed some light on this or tell me if my understanding is correct?
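
To make the count concrete, here is a minimal sketch of the arithmetic I have in mind. The function names are my own (hypothetical), and the code only models my assumption that every candidate is expanded at every step. It does not call transformers, and it may not reflect what generate actually does.

```python
def forward_passes_without_pruning(k: int, n: int) -> int:
    """Count forward passes for n tokens under my assumption.

    Assumption (mine, not confirmed): step t (0-indexed) needs k**t
    forward passes, i.e. 1 for the first token, k for the second,
    k**2 for the third, and so on.
    """
    return sum(k ** t for t in range(n))


def closed_form(k: int, n: int) -> int:
    """Closed form of the same geometric series: (k**n - 1) / (k - 1)."""
    if k == 1:
        return n  # greedy case: one forward pass per token
    return (k ** n - 1) // (k - 1)


if __name__ == "__main__":
    # Print the step-by-step sum next to the closed form for a few beam sizes.
    for k, n in [(1, 5), (2, 5), (4, 5)]:
        print(k, n, forward_passes_without_pruning(k, n), closed_form(k, n))
```

Under this assumption, k = 4 and n = 5 already gives 1 + 4 + 16 + 64 + 256 = 341 forward passes, which is what makes me think the cost grows as O(k^n).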