Why are some weights FP32 in Llama 3.1 405B FBGEMM FP8 Quantization?

I’ve noticed that for “meta-llama/Llama-3.1-405B-Instruct-FP8” some weights are F32, some BF16, and some F8_E4M3. Meanwhile, for “meta-llama/Llama-3.1-405B-Instruct” all weights are BF16.

I understand that when quantizing, not all weights can be quantized, so the final model will be mixed precision. But why are some weights in the quantized model stored in F32, which is actually higher precision than in the original, un-quantized model?
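
For reference, here is a minimal sketch of how the per-tensor dtypes can be inspected from a downloaded checkpoint shard with safetensors (the shard filename is just a placeholder):

```python
# Sketch: list the storage dtype of every tensor in one safetensors shard.
# The shard filename below is a placeholder, not the real shard name.
from collections import Counter

from safetensors import safe_open

shard = "model-00001-of-00191.safetensors"  # hypothetical shard file

dtype_counts = Counter()
with safe_open(shard, framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        dtype_counts[str(tensor.dtype)] += 1
        if "weight_scale" in name or "layernorm" in name:
            print(name, tensor.dtype)

print(dtype_counts)  # a mix of torch.float8_e4m3fn, torch.bfloat16, torch.float32
```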


This is an example with bitsandbytes 8-bit quantization, but the same should be possible with other quantization methods in Transformers and Accelerate.
In other words, they may have set it up that way during quantization so that the un-quantized weights alone could be offloaded to the CPU, since the model is so large (see the sketch below).
I don’t know why only one of the two models is set up that way.
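
As a rough sketch, assuming the standard Transformers + bitsandbytes API, that kind of setup looks like this (the model id is just the one from this thread):

```python
# Sketch: 8-bit bitsandbytes quantization where the modules that stay
# un-quantized can be kept in FP32 and offloaded to the CPU.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,  # offloaded modules stay in FP32 on CPU
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-405B-Instruct",  # example model id from the thread
    quantization_config=quant_config,
    device_map="auto",  # let accelerate split layers across GPU and CPU
)
```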

Interesting. I wonder whether these weights will be cast to BF16 when the model is loaded onto the GPU.

In the case of BNB’s 8-bit quantization, it seems that the data type at computation time is fixed to FP16 and FP32.
They were probably more concerned with easy CPU offloading.
In the case of 4-bit quantization with NF4, we can choose any compute dtype, including BF16 (see the sketch below). However, I heard that CPU offloading of 4-bit quantized weights is difficult because torch does not support it.
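
For example, a sketch of the 4-bit NF4 setup where the compute dtype is chosen explicitly (BF16 here); the model id is only for illustration:

```python
# Sketch: NF4 4-bit quantization with an explicitly chosen compute dtype.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # computation runs in BF16
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # smaller model, just for illustration
    quantization_config=quant_config,
    device_map="auto",
)
```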

I wonder whether, for a model that already ships pre-mixed like this one, loading it without any options keeps the weights as they are.
Most people load the LLM onto CUDA with the `torch_dtype=torch.bfloat16` option, though, so I would think they would be cast to BF16 (see the sketch below).
BF16 would be preferable to float16 for accuracy.
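
For instance, a sketch of that common loading pattern:

```python
# Sketch: the usual way of loading the checkpoint in BF16 across the GPUs.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-405B-Instruct-FP8",
    torch_dtype=torch.bfloat16,  # request BF16 for the weights that are not FP8-quantized
    device_map="auto",
)
```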

Some weights, like the layernorms, are kept in float32 for stability reasons. See e.g. `layer_norm` needs to be done in fp32 for fp16 inputs · Issue #66707 · pytorch/pytorch · GitHub
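
A minimal sketch of that pattern (upcast to FP32 for the normalization, then cast the result back), similar in spirit to what the norm layers in Transformers do:

```python
# Sketch: run layer norm in FP32 even when the activations are FP16/BF16,
# so the mean/variance reduction does not lose precision.
import torch
import torch.nn.functional as F

def stable_layer_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    orig_dtype = x.dtype
    out = F.layer_norm(x.float(), x.shape[-1:], weight.float(), None, eps)
    return out.to(orig_dtype)  # cast back to the activation dtype

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
print(stable_layer_norm(x, w).dtype)  # torch.float16
```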


Right, but for the original model all weights are BF16, so the quantized version actually has some weights at higher precision (FP32) than the original model.

Oh, I see. It is only the `weight_scale` values that are being stored in FP32.
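
That would explain it; the dequantization is roughly the following (just a sketch of the idea, not the actual FBGEMM kernel):

```python
# Sketch: an FP8 (e4m3) weight combined with its FP32 weight_scale.
# Keeping the scale in FP32 keeps the rescaling itself at full precision.
import torch

w_fp8 = torch.randn(16, 16).to(torch.float8_e4m3fn)    # quantized weight (placeholder values)
weight_scale = torch.rand(16, 1, dtype=torch.float32)  # per-row scale stored in FP32

w = w_fp8.to(torch.float32) * weight_scale  # dequantize
w_bf16 = w.to(torch.bfloat16)               # then cast to the compute dtype if needed
print(w_bf16.dtype, w_bf16.shape)
```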

I see. So it was implemented that way for a reason.