Convert pre-trained MHA weights to GQA weights

In LlamaConfig, the docstring for the num_key_value_heads field states: “When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group.” Is there an implementation of this conversion in transformers?
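For illustration, here is a minimal PyTorch sketch of the mean-pooling that the docstring describes. It rests on two assumptions. First, it assumes the Hugging Face Llama layout, where k_proj and v_proj are nn.Linear layers whose output rows are ordered head by head, so head h occupies rows h*head_dim through (h+1)*head_dim - 1. Second, it assumes that GQA group j consists of the consecutive original heads j*group_size through (j+1)*group_size - 1, which matches the order in which Llama's repeat_kv expands each KV head. The functions meanpool_kv and convert_state_dict are hypothetical helpers written for this example. They are not part of the transformers API.

```python
import torch


def meanpool_kv(t: torch.Tensor, num_heads: int, num_kv_heads: int) -> torch.Tensor:
    """Mean-pool an MHA k_proj/v_proj weight (num_heads*head_dim, hidden) or bias
    (num_heads*head_dim,) into its GQA shape (num_kv_heads*head_dim, ...).

    Hypothetical helper: assumes rows are grouped per head and that each GQA group
    is a run of consecutive original heads.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    if t.shape[0] % num_heads != 0:
        raise ValueError("first dimension must be divisible by num_heads")
    head_dim = t.shape[0] // num_heads
    group_size = num_heads // num_kv_heads
    rest = t.shape[1:]  # (hidden_size,) for a weight, () for a bias
    # Split rows into (group, head-within-group, head_dim, ...) and average the heads in each group.
    grouped = t.reshape(num_kv_heads, group_size, head_dim, *rest)
    return grouped.mean(dim=1).reshape(num_kv_heads * head_dim, *rest)


def convert_state_dict(state_dict: dict, num_heads: int, num_kv_heads: int) -> dict:
    """Return a copy of state_dict in which only the k_proj/v_proj tensors are mean-pooled."""
    converted = {}
    for name, tensor in state_dict.items():
        if ".k_proj." in name or ".v_proj." in name:
            converted[name] = meanpool_kv(tensor, num_heads, num_kv_heads)
        else:
            # q_proj, o_proj, MLP, norms and embeddings keep their original shapes.
            converted[name] = tensor
    return converted
```

To apply this to a real checkpoint, one possible approach is as follows. Take model.state_dict() from a loaded LlamaForCausalLM, make a copy of its config with num_key_value_heads set to the target value, instantiate a new LlamaForCausalLM from that config, and call load_state_dict with the converted dict. Expect some loss in quality. The GQA paper follows this conversion with a short round of additional training, which it calls "uptraining".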

I have the same question. Is there a solution?