I recently saw this Twitter thread https://twitter.com/_basilM/status/1625185484837208082?s=20 which discusses how adding a layer norm to the Key & Query matrices stabilizes training in transformers / Large Language Models (LLMs) / foundation models. I'd like this stability – especially in pre-trained HuggingFace (HF) models.
Thus, my question is: what is the recommended/best way to add this layer norm to the K/Q (Key/Query) activation matrices in (any) HuggingFace model?
I assume it doesn’t really make a big difference whether the model is pre-trained or not (hopefully). I could copy-paste some attention code and patch the layer norm in, but I wanted to see whether there is a less naive way to do it. For pre-trained models, I assume I can simply load the existing weights, initialize the new layer norm parameters, and then fine-tune a little so the old weights “become aware” of the new layer norm.
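To make my current suggestion concrete, here is a rough sketch of what I mean for a GPT-2-style HF model. Everything here is an assumption on my part: I patch via forward hooks on each block's fused `c_attn` projection, and I normalize Q and K over the full hidden dimension (rather than per head), which is just one possible placement.

```python
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel


def add_qk_layernorm(model):
    """Sketch: attach LayerNorm to the Q and K activations of every block.

    GPT-2's c_attn produces the concatenated [Q, K, V] tensor, so a forward
    hook can split it, normalize Q and K, and re-concatenate.
    """
    n_embd = model.config.n_embd
    for block in model.transformer.h:
        q_ln = nn.LayerNorm(n_embd)
        k_ln = nn.LayerNorm(n_embd)
        # Register as submodules so the new params are trained/saved too.
        block.attn.add_module("q_ln", q_ln)
        block.attn.add_module("k_ln", k_ln)

        def hook(module, inputs, output, q_ln=q_ln, k_ln=k_ln):
            # output: (batch, seq, 3 * n_embd) -> split into Q, K, V
            q, k, v = output.split(n_embd, dim=-1)
            return torch.cat([q_ln(q), k_ln(k), v], dim=-1)

        block.attn.c_attn.register_forward_hook(hook)
    return model


# Tiny randomly-initialized config just to check shapes; for a real model
# you would load pre-trained weights first, then patch and fine-tune.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100)
model = add_qk_layernorm(GPT2LMHeadModel(config))
out = model(torch.randint(0, 100, (1, 8)))
```

The hook approach avoids copy-pasting the whole attention implementation, but it silently depends on GPT-2's fused QKV layout, so it wouldn't transfer unchanged to other architectures.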
Thoughts welcomed!
(Other suggestions not using HF are welcome too! But I assume they basically require implementing the transformer from scratch in PyTorch/TF/JAX and adding the layer norm there.)
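For the from-scratch route, here is roughly how I imagine the attention module would look in plain PyTorch, with the layer norm applied per head to Q and K after projection (the exact placement of the norms is my assumption, based on my reading of the thread):

```python
import torch
import torch.nn as nn


class QKNormAttention(nn.Module):
    """Sketch: multi-head self-attention with LayerNorm on Q and K per head."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        # The new layers: normalize each head's Q and K over head_dim.
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = (B, T, self.n_heads, self.head_dim)
        # Apply the norms before the heads attend, then move heads to dim 1.
        q = self.q_norm(q.view(shape)).transpose(1, 2)
        k = self.k_norm(k.view(shape)).transpose(1, 2)
        v = v.view(shape).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attended = scores.softmax(dim=-1) @ v
        return self.out(attended.transpose(1, 2).reshape(B, T, C))


y = QKNormAttention(64, 4)(torch.randn(2, 5, 64))
```

This is basically the standard attention block with two extra `LayerNorm`s, which is why I suspect patching an existing implementation should be feasible rather than rewriting everything.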
refs:
- Twitter cross-post: https://twitter.com/BrandoHablando/status/1627047211807961088?s=20
- Reddit: https://www.reddit.com/r/pytorch/comments/115qkt0/how_to_implement_key_query_layer_normalized/
- Quora: https://www.quora.com/unanswered/How-do-I-implement-key-query-layer-normalized-transformers-LLMs-in-Huggingface