On the GPU Inference page, in the FlashAttention-2 section, it says:

> FlashAttention-2 can only be used when the model's dtype is fp16 or bf16. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2.
But just below that, it shows FlashAttention-2 being used with a 4-bit quantized model:

```python
# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    ...
)
```
These two statements seem contradictory to me, because a 4-bit quantized model isn't stored in fp16 or bf16, right?

Thanks in advance for any clarification.