FlashAttention-2's 16 bit requirement

On the GPU Inference docs page, in the FlashAttention-2 section, it says:

FlashAttention-2 can only be used when the model’s dtype is fp16 or bf16. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.

But then below that, it says FlashAttention-2 can be used with a 4-bit quantized model:

```python
from transformers import AutoModelForCausalLM

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
```

These two statements seem contradictory to me, because a 4-bit model isn't stored in fp16, right?

Thanks for your clarifications.

Hi @peterhung! Indeed, 4-bit and 8-bit quantization through bitsandbytes reduces the memory footprint of the model. However, when the output of a layer is computed, the weights of that layer are dequantized and cast to 32-bit or 16-bit precision on the fly, which is why it is compatible with FlashAttention-2.

Here is a more detailed explanation: Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA
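To make that explicit, here is a minimal sketch of loading a 4-bit model with the compute dtype pinned to bf16 so it satisfies FlashAttention-2's fp16/bf16 requirement. It assumes a recent transformers + bitsandbytes install and a FlashAttention-2-capable GPU, and `model_id` is just a placeholder checkpoint for illustration:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder checkpoint; substitute any causal LM you have access to.
model_id = "meta-llama/Llama-2-7b-hf"

# Weights are stored in 4-bit, but each layer is dequantized to bf16
# at compute time, which is what FlashAttention-2 requires.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```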

Hi @regisss

Thanks for your explanations - they’re clear.

And thanks for the link - I'll check it out as well.