Suppose I have a Llama-3.2-3B model as follows:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=3072, out_features=128256, bias=False)
)
How do I separate the individual attention heads out of the (q_proj), (k_proj), and (v_proj) weight matrices?
Assuming the number of heads is known, my main concern is whether I should split the matrices row-wise or column-wise.
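
To make the row-wise option concrete, here is a minimal sketch rather than a definitive answer. It relies on two things: nn.Linear stores its weight with shape (out_features, in_features), and the Hugging Face Llama attention code reshapes the projection output to (batch, seq_len, num_heads, head_dim). Under those conditions each head is a block of head_dim consecutive rows of the weight. The head counts below (24 query heads, 8 key/value heads, head_dim = 128) are assumptions that match the shapes in the printout (3072 = 24 x 128, 1024 = 8 x 128); the layers are randomly initialised stand-ins, and split_heads is a hypothetical helper, not a library function.

```python
import torch
import torch.nn as nn

hidden_size = 3072
head_dim = 128      # assumed: 3072 / 24 query heads
num_heads = 24      # assumed number of query heads
num_kv_heads = 8    # assumed number of key/value heads (1024 / 128)

# Randomly initialised stand-ins for one layer's q_proj and k_proj.
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)


def split_heads(linear, n_heads, head_dim):
    """Hypothetical helper: split a projection weight into per-head blocks.

    The weight is (out_features, in_features), so head h owns rows
    h * head_dim through (h + 1) * head_dim - 1.
    """
    w = linear.weight.detach()
    return w.reshape(n_heads, head_dim, w.shape[1])


q_heads = split_heads(q_proj, num_heads, head_dim)     # (24, 128, 3072)
k_heads = split_heads(k_proj, num_kv_heads, head_dim)  # (8, 128, 3072)

# Sanity check: projecting with one head's rows should reproduce that
# head's slice of the full projection output.
x = torch.randn(2, 5, hidden_size)
full = q_proj(x).view(2, 5, num_heads, head_dim)
h = 3
per_head = x @ q_heads[h].T
matches = torch.allclose(full[:, :, h, :], per_head, atol=1e-4)
```

The same helper would apply to v_proj with num_kv_heads. If the sanity check holds on the real weights as well, that suggests the row-wise split is the one consistent with how the model itself forms the heads.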