Token merging for fast LLM inference

Hello all,

I worked on a project aimed at speeding up LLM inference by merging tokens in the sequence. The core idea is that to predict the nth token, the model does not need tokens 1 to n-1 individually, so they can be merged using SLERP (spherical linear interpolation). I ran a first experiment with Mistral 7B Instruct, and it works.
The sequence length is reduced by roughly a factor of 2, and the output quality is still satisfying. My code is here: GitHub - samchaineau/llm_slerp_generation: Repo hosting codes and materials related to speeding LLMs' generative abilities while preserving quality using token merging.
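For readers unfamiliar with SLERP: it interpolates along the arc between two vectors rather than along the straight line between them, so two unit vectors merge into another unit vector (a plain average would shrink its norm). Below is a minimal NumPy sketch, assuming the merge is applied to per-token vectors (for example hidden states of shape (seq_len, dim)) and that adjacent tokens are merged in pairs at t = 0.5, which would be consistent with the roughly 2x reduction. Both assumptions and the function names `slerp` and `merge_adjacent_pairs` are hypothetical and not taken from the repo; see the repository for the actual implementation.

```python
import numpy as np


def slerp(a: np.ndarray, b: np.ndarray, t: float = 0.5, eps: float = 1e-8) -> np.ndarray:
    """Spherical linear interpolation between vectors a and b at fraction t."""
    # The angle is measured between the normalized directions...
    a_dir = a / (np.linalg.norm(a) + eps)
    b_dir = b / (np.linalg.norm(b) + eps)
    dot = np.clip(np.dot(a_dir, b_dir), -1.0, 1.0)  # keep arccos in its domain
    theta = np.arccos(dot)
    sin_theta = np.sin(theta)
    if sin_theta < eps:
        # (Anti)parallel vectors: SLERP is ill-defined, fall back to linear interpolation
        return (1.0 - t) * a + t * b
    # ...and the resulting weights are applied to the original (unnormalized) vectors
    w_a = np.sin((1.0 - t) * theta) / sin_theta
    w_b = np.sin(t * theta) / sin_theta
    return w_a * a + w_b * b


def merge_adjacent_pairs(states: np.ndarray) -> np.ndarray:
    """Hypothetical merge step: SLERP consecutive rows of a (seq_len, dim) array.

    An odd trailing row is kept unchanged, so the output has ceil(seq_len / 2) rows.
    """
    merged = [slerp(states[i], states[i + 1]) for i in range(0, len(states) - 1, 2)]
    if len(states) % 2 == 1:
        merged.append(states[-1])
    return np.stack(merged)


rng = np.random.default_rng(0)
hidden = rng.normal(size=(10, 16))  # stand-in for 10 token vectors of width 16
merged = merge_adjacent_pairs(hidden)
print(hidden.shape, "->", merged.shape)  # (10, 16) -> (5, 16)
```

Which tokens get merged, and at which layer, is where a real implementation would differ; this only shows the interpolation itself and the halving of the sequence.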

Here is a diagram representing my view:

If anyone is interested, reach out to me! I think this could be a useful addition to the accelerate library.

Here is a demo where I generate more than 128 tokens with only 95 elements in the sequence.
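The post does not say when merging is triggered during generation. As a purely hypothetical illustration of how a sequence can stay shorter than the number of generated tokens, the sketch below only tracks lengths: whenever the sequence grows past a cap `max_len`, every element except the `keep_recent` newest ones is merged pairwise. The function name and all parameters are invented for this example; it is not the repo's schedule and is not meant to reproduce the 95-element figure above.

```python
def simulate_sequence_length(prompt_len: int, new_tokens: int, max_len: int, keep_recent: int) -> int:
    """Return the sequence length after generating new_tokens under a hypothetical merge policy.

    Whenever the sequence grows past max_len, every element except the
    keep_recent newest ones is merged pairwise (an odd leftover is kept as-is),
    so that older span shrinks to ceil(n / 2) elements.
    """
    if not 0 <= keep_recent < max_len:
        raise ValueError("keep_recent must satisfy 0 <= keep_recent < max_len")
    if prompt_len > max_len:
        raise ValueError("this sketch assumes the prompt fits under max_len")
    length = prompt_len
    for _ in range(new_tokens):
        length += 1  # one newly generated token is appended
        while length > max_len:
            old = length - keep_recent  # elements eligible for merging
            length = (old + 1) // 2 + keep_recent  # pairwise merge, rounding up
    return length


# 20 prompt tokens + 128 generated tokens would be 148 elements without merging.
final_len = simulate_sequence_length(prompt_len=20, new_tokens=128, max_len=96, keep_recent=32)
print(final_len)  # 84
```

Keeping the most recent tokens unmerged is just one plausible design choice here, on the intuition that the newest context matters most for the next prediction.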
