How to implement Key-Query Layer Normalization in Hugging Face Transformers/LLMs?

I recently saw this Twitter thread https://twitter.com/_basilM/status/1625185484837208082?s=20 discussing how adding a layer norm to the key and query matrices stabilizes training in transformers / Large Language Models (LLMs) / foundation models. I would like that stability, especially in pre-trained Hugging Face (HF) models.

Thus, my question is: what is the recommended/best way to add this layer norm to the K/Q (key/query) activation matrices in (any) Hugging Face model?

I assume it doesn’t really make a big difference whether the model is pre-trained or not (hopefully). I assume I could copy-paste some attention code and insert the layer norms there, but I wanted to see whether there is a less naive way to do it. For pre-trained models, I assume that with my current suggestion I can simply load the existing weights and fine-tune a little so the old weights “become aware” of the new layer norm (and so the freshly initialized layer norm parameters get adapted). A rough sketch of what I have in mind is below.
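To make my current suggestion concrete, here is a rough, untested sketch: load a pre-trained GPT-2, then register a forward hook on each block’s fused `c_attn` projection that layer-normalizes the query and key slices per head before attention. The per-head `LayerNorm` placement is my reading of the thread (normalizing over the full hidden size would be a simpler variant), and the code assumes GPT-2’s fused q/k/v layout in recent `transformers` versions; other architectures would need their own patch.

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
cfg = model.config
n_heads = cfg.num_attention_heads
head_dim = cfg.hidden_size // n_heads

def add_qk_layernorm(block):
    # Fresh LayerNorms over the per-head dimension of queries and keys.
    # Assigning them as attributes of the attention module registers their
    # parameters, so they get optimized and saved along with the model.
    block.attn.q_ln = nn.LayerNorm(head_dim)
    block.attn.k_ln = nn.LayerNorm(head_dim)

    def per_head_norm(x, ln):
        b, s, _ = x.shape
        return ln(x.reshape(b, s, n_heads, head_dim)).reshape(b, s, -1)

    def hook(module, inputs, output):
        # GPT-2's fused c_attn projection emits [q | k | v] on the last dim;
        # normalize q and k per head, leave v alone, return the new output.
        q, k, v = output.split(cfg.hidden_size, dim=-1)
        q = per_head_norm(q, block.attn.q_ln)
        k = per_head_norm(k, block.attn.k_ln)
        return torch.cat([q, k, v], dim=-1)

    block.attn.c_attn.register_forward_hook(hook)

for block in model.transformer.h:
    add_qk_layernorm(block)

# Smoke test: the patched model still runs end to end.
ids = torch.randint(0, cfg.vocab_size, (1, 8))
print(model(ids).logits.shape)
```

Since LayerNorm is not an identity at initialization, the patched model’s outputs shift a bit, which is exactly why I’d fine-tune briefly as described above (and the model plus the new LayerNorms would need to be moved to the right device/dtype after patching).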

Thoughts welcomed!

(Other suggestions not using HF are welcome too! But I assume they basically require implementing the transformer from scratch and adding the layer norm yourself in PyTorch/TF/JAX; see the minimal sketch after this paragraph.)
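For completeness, here is a minimal from-scratch PyTorch sketch of multi-head self-attention with the extra layer norms on queries and keys. The class name `QKNormSelfAttention` is just mine, there is no masking/dropout, and it is not meant to reproduce the thread exactly, only to show where the norms go.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKNormSelfAttention(nn.Module):
    """Multi-head self-attention with LayerNorm on queries and keys."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        # The extra layer norms: applied per head to q and k, right before
        # the dot product.
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # (b, s, d) -> (b, num_heads, s, head_dim)
            return t.reshape(b, s, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        q, k = self.q_norm(q), self.k_norm(k)  # <- key/query layer norm
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        out = F.softmax(scores, dim=-1) @ v
        return self.out(out.transpose(1, 2).reshape(b, s, d))

# Quick shape check.
attn = QKNormSelfAttention(hidden_size=256, num_heads=8)
print(attn(torch.randn(2, 16, 256)).shape)  # torch.Size([2, 16, 256])
```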
