Enabling Flash Attention 2

What is the difference between using Flash Attention 2 via

model = AutoModelForCausalLM.from_pretrained(ckpt, attn_implementation="sdpa")

vs

model = AutoModelForCausalLM.from_pretrained(ckpt, attn_implementation="flash_attention_2")

when PyTorch SDPA supports FA2 according to the docs?

@marcsun13


@ybelkada Can you shed some light on this?

@varadhbhatnagar were you able to figure out the difference?