Hey all!
I think I have a reasonably satisfying answer to this question.
If you load Mistral-7B without quantization and print out the named parameters for the embedding layer and the first decoder layer (along with each parameter's shape, number of elements, and whether it's trainable), here's what you'll see:
Parameter Name Dimensions Total Values Trainable
==== Embedding Layer ====
model.embed_tokens.weight 32,000 x 4,096 125M True
==== First Decoder ====
model.layers.0.self_attn.q_proj.weight 4,096 x 4,096 16M True
model.layers.0.self_attn.k_proj.weight 1,024 x 4,096 4M True
model.layers.0.self_attn.v_proj.weight 1,024 x 4,096 4M True
model.layers.0.self_attn.o_proj.weight 4,096 x 4,096 16M True
model.layers.0.mlp.gate_proj.weight 14,336 x 4,096 56M True
model.layers.0.mlp.up_proj.weight 14,336 x 4,096 56M True
model.layers.0.mlp.down_proj.weight 4,096 x 14,336 56M True
model.layers.0.input_layernorm.weight 4,096 x - 4K True
model.layers.0.post_attention_layernorm.weight 4,096 x - 4K True
These are the correct weight matrix shapes and parameter counts.
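For reference, a listing like the one above can be produced with a loop along these lines. (The checkpoint name and the print formatting are my own choices here, not something special about the original output.)

```python
# Minimal sketch: load Mistral-7B in full precision and print the named
# parameters for the embeddings and the first decoder layer.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",   # assumed checkpoint; any Mistral-7B variant works
    torch_dtype=torch.bfloat16,
)

for name, param in model.named_parameters():
    # Only look at the embedding layer and the first decoder layer.
    if name.startswith("model.embed_tokens") or name.startswith("model.layers.0."):
        dims = " x ".join(str(d) for d in param.shape)
        print(f"{name:55} {dims:>15} {param.numel():>12,} {param.requires_grad}")
```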
But with 4-bit quantization enabled, this becomes:
Parameter Name Dimensions Total Values Trainable
==== Embedding Layer ====
model.embed_tokens.weight 32,000 x 4,096 125M True
==== First Decoder ====
model.layers.0.self_attn.q_proj.weight 8,388,608 x 1 8M False
model.layers.0.self_attn.k_proj.weight 2,097,152 x 1 2M False
model.layers.0.self_attn.v_proj.weight 2,097,152 x 1 2M False
model.layers.0.self_attn.o_proj.weight 8,388,608 x 1 8M False
model.layers.0.mlp.gate_proj.weight 29,360,128 x 1 28M False
model.layers.0.mlp.up_proj.weight 29,360,128 x 1 28M False
model.layers.0.mlp.down_proj.weight 29,360,128 x 1 28M False
model.layers.0.input_layernorm.weight 4,096 x - 4K True
model.layers.0.post_attention_layernorm.weight 4,096 x - 4K True
The weight matrices have been flattened into 1-D tensors, the reported number of elements has been cut in half (two 4-bit values are packed into each stored 8-bit element), and the weights are no longer marked as trainable.
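For completeness, the 4-bit listing comes from loading the same model with a bitsandbytes quantization config, something like the sketch below. The exact config values are common defaults I picked for illustration, not necessarily what you're using.

```python
# Same inspection loop, but with 4-bit quantization enabled via bitsandbytes.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",   # assumed checkpoint
    quantization_config=bnb_config,
    device_map="auto",             # requires accelerate
)

for name, param in model_4bit.named_parameters():
    if name.startswith("model.embed_tokens") or name.startswith("model.layers.0."):
        # The 4-bit linear weights now show up as flattened uint8 tensors
        # (two 4-bit values packed per byte), with requires_grad=False.
        dims = " x ".join(str(d) for d in param.shape)
        print(f"{name:55} {dims:>15} {param.numel():>12,} {param.requires_grad}")
```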
Someone discussed this a bit here, saying: “in general the quantized weight is not simply saved as a quantized tensor with X elements each having Y bits, rather it has to be saved as packedparams…”.
So, with the quantized model, if you try to count the model parameters by looping over the weights and tallying their numel, you'll get the wrong total.
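If you do need an accurate count from the quantized model, one workaround is to read the original shape that bitsandbytes stores alongside each packed weight instead of calling numel() on it. This is only a sketch: the quant_state.shape attribute is what recent bitsandbytes versions expose, so treat it as an assumption if you're on an older release.

```python
# Sketch: count the "logical" parameters of a 4-bit quantized model by
# recovering each quantized weight's original shape from its quant_state.
import bitsandbytes as bnb

def true_param_count(model):
    total = 0
    for module in model.modules():
        if isinstance(module, bnb.nn.Linear4bit):
            # numel() on the packed uint8 tensor would undercount by 2x;
            # quant_state remembers the original (out_features, in_features).
            out_features, in_features = module.weight.quant_state.shape
            total += out_features * in_features
            if module.bias is not None:
                total += module.bias.numel()
        else:
            # Non-quantized parameters owned directly by this module
            # (embeddings, layernorms, lm_head, ...).
            total += sum(p.numel() for p in module.parameters(recurse=False))
    return total

print(f"{true_param_count(model_4bit):,}")
```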