Fewer Trainable Parameters after quantization

So I went and investigated how the parameters are actually counted, and this is the behaviour (using Mistral 7B as an example): the quantized layers report half as many parameters as the non-quantized layers.
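If you want a count that reflects the original number of weights, one workaround is to double numel() for the packed 4-bit parameters. This is only a sketch under assumptions: it assumes the packed weights are parameters whose class is named Params4bit (as in bitsandbytes) and that each stored element holds two 4-bit weights. The name count_parameters is made up for this example.

```python
def count_parameters(model, trainable_only=False):
    """Count parameters, treating each packed 4-bit element as two weights.

    Assumption: 4-bit weights live in parameters whose class is named
    "Params4bit", and each stored element packs two 4-bit values.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        n = param.numel()
        if type(param).__name__ == "Params4bit":
            n *= 2  # two 4-bit weights per stored element
        total += n
    return total
```

Calling count_parameters(model) on a quantized model should then give roughly the same total as on the non-quantized one, as long as the assumptions above hold for your bitsandbytes version.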

And the reason for this is the actual Linear4bit implementation (bitsandbytes/bitsandbytes/nn/modules.py at 048a2d404c6a909e6f835ba18182fbfae130ba09 · TimDettmers/bitsandbytes · GitHub)

Gemini explains this better than I can, haha:

Let's consider a simplified example with a Linear4bit layer having 4 input features and 4 output features. This means the original weight tensor (before quantization) would have a shape of [4, 4], containing 16 individual weights.

1. Original Weight Tensor (FP16):
weights = torch.randn(4, 4, dtype=torch.float16)
This might output something like:
tensor([[ 0.2344, -0.1234,  0.5678, -0.9876],
        [-0.4567,  0.8765, -0.3456,  0.7890],
        [ 0.1234, -0.5678,  0.9876, -0.2345],
        [-0.7890,  0.3456, -0.7890,  0.1234]], dtype=torch.float16)

2. Quantization and Packing:
When we quantize this layer using Linear4bit, each of these 16 weights will be converted to a 4-bit representation. Since 8 bits make up a byte, we can pack two 4-bit weights into a single element of a torch.uint8 tensor.

3. Packed Weight Tensor (uint8):
After quantization and packing, the weight tensor will have a shape of [8, 1]. Each element in this tensor will hold two 4-bit quantized weights.
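The exact values in the packed tensor will depend on the specific quantization map used (e.g., fp4 or nf4). However, the key point is that the information from the original 16 weights is now stored in 8 elements, so the tensor has half as many elements. In bytes the saving is larger: 16 FP16 weights take 32 bytes, while the 8 packed uint8 elements take 8 bytes, a 75% reduction for the weights (not counting the extra quantization statistics that are stored alongside them).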
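To keep element count and byte count apart, here is the arithmetic for the 4x4 example as a small sketch. The helper weight_bytes is made up for illustration, and the numbers ignore the quantization statistics stored next to the weights.

```python
def weight_bytes(num_weights, bits_per_weight):
    """Bytes needed to store num_weights values, rounded up to whole bytes."""
    return (num_weights * bits_per_weight + 7) // 8

num_weights = 4 * 4                          # the [4, 4] example layer
fp16_bytes = weight_bytes(num_weights, 16)   # 16 weights * 2 bytes
packed_bytes = weight_bytes(num_weights, 4)  # two weights per byte
reduction = 1 - packed_bytes / fp16_bytes

print(fp16_bytes, packed_bytes, reduction)   # 32 8 0.75
```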
Step by step:
1. Original Shape:
We start with a weight tensor of shape [4, 4], containing 16 individual 16-bit weights.
2. 4-bit Quantization:
Each 16-bit weight is converted to a 4-bit representation. We now have 16 4-bit weights.
3. Packing:
Since 8 bits make up a byte, we can pack two 4-bit weights into a single element of a torch.uint8 tensor.
Therefore, the 16 4-bit weights can be packed into 8 elements.
4. Final Shape:
The most logical and efficient way to store these 8 packed elements would be in a 1D tensor of shape [8].
However, the Linear4bit layer implementation uses a slightly different shape, presumably for internal implementation reasons.
It stores the packed weights in a 2D tensor of shape [8, 1].
This shape has the same number of elements (8) as the 1D [8] shape and, for a contiguous tensor, the same layout in memory; only the shape metadata differs. Either way, numel() on the packed weight returns 8 rather than 16, which is why the quantized layers appear to have half the parameters.
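The packing step itself can be sketched in plain Python, so it runs without torch or bitsandbytes. This is not the bitsandbytes kernel: the function names are made up, the inputs are assumed to be 4-bit codes (indices 0..15 into a quantization map such as fp4 or nf4), and putting the first code in the high half of the byte is an assumption made for the example.

```python
def pack_4bit(codes):
    """Pack a flat list of 4-bit codes (0..15) into bytes, two codes per byte."""
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("every code must fit in 4 bits (0..15)")
    codes = list(codes)
    if len(codes) % 2:
        codes.append(0)  # pad so the last byte is complete
    # Assumed order: first code in the high 4 bits, second in the low 4 bits.
    return bytes((hi << 4) | lo for hi, lo in zip(codes[0::2], codes[1::2]))

def unpack_4bit(packed, count):
    """Recover the first `count` 4-bit codes from the packed bytes."""
    codes = []
    for byte in packed:
        codes.append(byte >> 4)    # high 4 bits
        codes.append(byte & 0x0F)  # low 4 bits
    return codes[:count]

codes = list(range(16))            # stand-in for 16 quantized weights
packed = pack_4bit(codes)
print(len(codes), len(packed))     # 16 8
```

The 16 codes end up in 8 stored elements, which is the same halving that numel() reports for the quantized layers.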