KNN-LM with Clustering Centroids for Continuous Learning

Two persistent challenges in training LLMs are catastrophic forgetting and the high cost of continuously retraining the model. I have experimented with a new architecture idea that combines a kNN-LM with clustering to keep the memory size constant. The kNN-LM setup is described here:

https://ai.meta.com/research/publications/generalization-through-memorization-nearest-neighbor-language-models/

In that approach, a Transformer is paired with a k-nearest-neighbors lookup. The kNN component works like a vector database: the Transformer produces a context representation, which is compared against stored entries that each pair a context representation with the token that followed it; the resulting nearest-neighbor distribution over next tokens is then interpolated with the Transformer's own prediction. With a kNN-LM, new facts could simply be inserted into the vector database, allowing the model to improve continuously, but the database becomes impractically large.
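
To make the retrieval-plus-interpolation step concrete, here is a minimal sketch, not the exact setup from the paper or my repository. The array shapes, the brute-force L2 search, and the interpolation weight `lam` are all assumptions for illustration:

```python
import numpy as np

# Hypothetical datastore: each entry pairs a context key (hidden state)
# with the token that followed that context in the training data.
rng = np.random.default_rng(0)
dim, vocab_size, n_entries = 64, 1000, 10_000
keys = rng.normal(size=(n_entries, dim)).astype(np.float32)   # context representations
values = rng.integers(0, vocab_size, size=n_entries)          # next-token ids

def knn_lm_probs(query, model_probs, k=8, temperature=1.0, lam=0.25):
    """Blend a kNN next-token distribution with the model's own distribution.

    query        : hidden state for the current context, shape (dim,)
    model_probs  : Transformer softmax output, shape (vocab_size,)
    lam          : interpolation weight given to the kNN distribution
    """
    # Brute-force L2 search; a real system would use an ANN index such as FAISS.
    dists = np.linalg.norm(keys - query, axis=1)
    nearest = np.argsort(dists)[:k]

    # Convert negative distances into weights over the retrieved entries.
    weights = np.exp(-dists[nearest] / temperature)
    weights /= weights.sum()

    # Accumulate the weights onto the retrieved next-token ids.
    knn_probs = np.zeros(vocab_size, dtype=np.float64)
    np.add.at(knn_probs, values[nearest], weights)

    # p(w) = lam * p_knn(w) + (1 - lam) * p_model(w)
    return lam * knn_probs + (1 - lam) * model_probs
```

Adding a new fact then amounts to appending another (key, value) row to the datastore, which is what makes the database growth the main bottleneck.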

Instead of letting the database grow indefinitely, similar entries could be clustered and replaced with a single centroid, so the overall size remains fixed. For example, if the database reaches 100M entries, we cluster them until it’s reduced back to 10M. This is analogous to human memory, where similar experiences are merged, generalized, and not redundantly stored.
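
Here is a minimal sketch of that compression step, with assumed details: the use of scikit-learn's KMeans, the per-centroid next-token distribution, and the `target_size` parameter are illustrative choices, not necessarily what the linked repository does:

```python
import numpy as np
from sklearn.cluster import KMeans

def compress_datastore(keys, values, vocab_size, target_size):
    """Cluster datastore keys and keep one centroid per cluster.

    keys        : context representations, shape (n_entries, dim)
    values      : next-token ids, shape (n_entries,)
    target_size : number of centroids to keep after compression
    """
    km = KMeans(n_clusters=target_size, n_init=10, random_state=0).fit(keys)
    centroid_keys = km.cluster_centers_.astype(np.float32)

    # Aggregate the next-token labels of each cluster into a distribution,
    # so a centroid summarizes all the entries it replaces.
    centroid_dists = np.zeros((target_size, vocab_size), dtype=np.float32)
    for cluster_id, token in zip(km.labels_, values):
        centroid_dists[cluster_id, token] += 1.0
    centroid_dists /= centroid_dists.sum(axis=1, keepdims=True)

    return centroid_keys, centroid_dists
```

At query time the centroids are searched exactly like ordinary entries; the only difference is that each hit contributes a small distribution over next tokens rather than a single label.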

This method could address catastrophic forgetting and the high cost of continual retraining. New facts can simply be added to the vector database, while clustering ensures it stays efficient and practical over time. Here is a link to the code:

https://github.com/Ascanius365/KNN-LM-with-Clustering-Centroids
