On the GPU Inference docs page, in the FlashAttention-2 section, it says:
> FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
But below that, it says FlashAttention-2 can be used with a 4-bit quantized model:
```python
from transformers import AutoModelForCausalLM

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
```
These two statements seem contradictory to me, because a 4-bit quantized model's weights aren't stored in fp16, right?
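To make what I mean concrete, here is a minimal sketch I would use to inspect the dtypes actually present after 4-bit loading (assuming `bitsandbytes` and `accelerate` are installed and a CUDA GPU is available; the model id is just an example):

```python
from collections import Counter

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # example model id, any FA-2-compatible model
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)

# Count which dtypes occur among the model's parameters.
dtype_counts = Counter(p.dtype for p in model.parameters())
print(dtype_counts)
# I would expect the quantized linear weights to show up as torch.uint8
# (packed 4-bit storage), with only the non-quantized layers (embeddings,
# norms, etc.) remaining in fp16/bf16.
```

So the model as a whole is not in `fp16`/`bf16`, which is what makes the two statements look contradictory to me.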
Thanks for your clarifications.