I am computing some values from the model's weights and its input. I want to use “meta-llama/Meta-Llama-3-8B-Instruct” for a generation task.
The weights of the first transformer layer have the following shapes:
model.embed_tokens.weight: torch.Size([128256, 4096])
model.layers.0.self_attn.q_proj.weight: torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight: torch.Size([1024, 4096])
model.layers.0.self_attn.v_proj.weight: torch.Size([1024, 4096])
model.layers.0.self_attn.o_proj.weight: torch.Size([4096, 4096])
model.layers.0.mlp.gate_proj.weight: torch.Size([14336, 4096])
model.layers.0.mlp.up_proj.weight: torch.Size([14336, 4096])
model.layers.0.mlp.down_proj.weight: torch.Size([4096, 14336])
model.layers.0.input_layernorm.weight: torch.Size([4096])
model.layers.0.post_attention_layernorm.weight: torch.Size([4096])
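(For reference, I printed these shapes with roughly the snippet below; the loading arguments are just what I happened to use, assuming transformers and torch are installed.)

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
)

# Print the shape of every parameter in the embedding and the first layer
for name, param in model.named_parameters():
    if name.startswith("model.embed_tokens") or name.startswith("model.layers.0."):
        print(f"{name}: {param.shape}")
```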
Now, when I use a 4-bit quantized version, “unsloth/llama-3-8b-bnb-4bit”, the weights of the first transformer layer have the following shapes:
model.embed_tokens.weight: torch.Size([128256, 4096])
model.layers.0.self_attn.q_proj.weight: torch.Size([8388608, 1])
model.layers.0.self_attn.k_proj.weight: torch.Size([2097152, 1])
model.layers.0.self_attn.v_proj.weight: torch.Size([2097152, 1])
model.layers.0.self_attn.o_proj.weight: torch.Size([8388608, 1])
model.layers.0.mlp.gate_proj.weight: torch.Size([29360128, 1])
model.layers.0.mlp.up_proj.weight: torch.Size([29360128, 1])
model.layers.0.mlp.down_proj.weight: torch.Size([29360128, 1])
model.layers.0.input_layernorm.weight: torch.Size([4096])
model.layers.0.post_attention_layernorm.weight: torch.Size([4096])
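I loaded this one directly from the prequantized repo; with bitsandbytes installed, something like this worked for me (same shape-printing loop as above):

```python
from transformers import AutoModelForCausalLM

# The repo ships bitsandbytes 4-bit weights, so bitsandbytes must be installed
qmodel = AutoModelForCausalLM.from_pretrained(
    "unsloth/llama-3-8b-bnb-4bit",
    device_map="auto",
)

for name, param in qmodel.named_parameters():
    if name.startswith("model.embed_tokens") or name.startswith("model.layers.0."):
        print(f"{name}: {param.shape}")
```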
To the best of my (limited) knowledge, the quantization step converts float16 or float32 values to int4 or int8, and for fast access the weights are flattened to 1-D.
But take, say, the q_proj weight in the self-attention of the first transformer layer of “meta-llama/Meta-Llama-3-8B-Instruct”: its shape is torch.Size([4096, 4096]). Flattened to 1-D, that would be torch.Size([16777216, 1]). Yet the corresponding weight in “unsloth/llama-3-8b-bnb-4bit” has shape **torch.Size([8388608, 1])**, exactly half as many elements.
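My guess is that two 4-bit values are packed into each uint8 byte, which would explain the factor of two, but I would like confirmation. The arithmetic I have in mind:

```python
fp16_elems = 4096 * 4096         # 16,777,216 values in the original q_proj
packed_bytes = fp16_elems // 2   # 8,388,608 if two 4-bit values share one byte
print(fp16_elems, packed_bytes)  # 16777216 8388608
```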
I have two questions:
- How is the quantized weight shape in this particular case **torch.Size([8388608, 1])**?
- If I want to reshape the weight for some calculation, how can I do it (i.e., go from **torch.Size([8388608, 1])** back to **torch.Size([4096, 4096])**)? My attempt so far is below.
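For context, here is roughly what I have tried. I am assuming the 4-bit weights are bitsandbytes Params4bit tensors carrying a quant_state, and that dequantize_4bit is the right tool; I am not sure this is the intended route:

```python
import torch
from bitsandbytes.functional import dequantize_4bit

# q_proj.weight is a bitsandbytes Params4bit tensor of shape [8388608, 1]
w = qmodel.model.layers[0].self_attn.q_proj.weight

# dequantize_4bit should unpack the packed nibbles and rescale them per block,
# giving back a regular tensor with the original shape
deq = dequantize_4bit(w.data, w.quant_state)
print(deq.shape)  # hoping for torch.Size([4096, 4096])
```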