Flash attention has no effect on inference

Hi, I was exploring the benefits of using flash attention 2 with Mistral and Mixtral during inference. Yet, I can see no memory reduction and no speed-up. Some numbers under different attention implementations:

Mixtral (mistralai/Mixtral-8x7B-Instruct-v0.1):
attn_implementation='flash_attention_2': 27.335 GB, 15.8 seconds
attn_implementation='eager': 27.335 GB, 16.1 seconds
attn_implementation='sdpa': 27.335 GB, 15.4 seconds

Mistral (mistralai/Mistral-7B-Instruct-v0.2):
attn_implementation='flash_attention_2': 6.407 GB, 31.5 seconds
attn_implementation='eager': 6.407 GB, 30.7 seconds
attn_implementation='sdpa': 6.405 GB, 28.9 seconds

(Mistral took longer overall than Mixtral because I tested it on 20 examples, each with a smaller max_new_tokens.)

The code I used to test Mixtral is taken from the HuggingFace page:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to load the model onto
attn_implementation = "flash_attention_2"  # also tested with "eager" and "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    torch_dtype=torch.float16,
    attn_implementation=attn_implementation,
    cache_dir="cache",
    load_in_4bit=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "My favourite condiment is"

model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

%time generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]
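
For reference, per-implementation peak memory and generation time can be collected with a loop roughly like the one below (this is an illustrative sketch of the measurement setup, not code from the HuggingFace page; the torch.cuda peak-memory counters and the timing are my additions):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_inputs = tokenizer(["My favourite condiment is"], return_tensors="pt").to("cuda")

# Reload the model under each attention implementation and measure peak
# GPU memory and wall-clock generation time for the same prompt.
for attn_implementation in ["flash_attention_2", "sdpa", "eager"]:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
        load_in_4bit=True,
        device_map="auto",
    )
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    _ = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{attn_implementation}: {peak_gb:.3f} GB, {elapsed:.1f} s")
    del model
    torch.cuda.empty_cache()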

Library versions:

transformers==4.37.2
torch==2.2.0
flash-attn==2.5.3

My GPU is an NVIDIA A100-SXM4-40GB with CUDA release 12.3, V12.3.107.

I know that the major benefit of flash-attention-2 shows up during training, yet it is reported to be beneficial during inference as well: HuggingFace Page. Hence, I wonder whether there is a problem with my environment or with the current implementations.
Thanks in advance!

@ybelkada hope you can shed some light here

I'm also experiencing this using Mistral-7B-v0.1 on a 4090. Running with:

  • 4096 tokens
  • batch size of 1
  • No padding

Multi-head attention takes ~144 seconds.
Flash attention takes ~154 seconds.

Hi everyone!
@Aktsvigun to see the effect of FA2, you need to run inference on a large context length. In the benchmarks I ran where I saw interesting speedups, I used ~2048 tokens.
I used this script to benchmark everything: Benchmark FA2 + transformers integration · GitHub
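
As a rough illustration of that advice (this is not the exact benchmark script; the checkpoint, prompt, and token counts below are just placeholders):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# Repeat some text so the prompt reaches roughly 2048 tokens; FA2 gains
# show up at long context lengths, not on a handful of prompt tokens.
long_prompt = "My favourite condiment is something I think about a lot. " * 150
inputs = tokenizer([long_prompt], return_tensors="pt").to(model.device)
print("prompt tokens:", inputs.input_ids.shape[1])

start = time.perf_counter()
_ = model.generate(**inputs, max_new_tokens=256, do_sample=False)
torch.cuda.synchronize()
print(f"generation time: {time.perf_counter() - start:.1f} s")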
@kreas can you try with transformers==4.35.2?


Changing the transformers version doesn't seem to affect anything.

I found out that I had some NVTX calls which added overhead, but even after removing them, flash attention is still slower on Mistral. I've slightly modified the script you linked in your comment. These are the results:

[image: benchmark results]

It speeds up other models such as phi-2 by around 3x for the same text, while it actually slows down Mistral v0.2 Instruct.

I think Mistral even says in some of the model's descriptions that it is very fast because it uses flash attention. Not sure. Maybe it is used by default if available.

@Aktsvigun try

torch_dtype=torch.bfloat16
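
i.e. roughly (a minimal sketch; the other arguments are just carried over from the snippet above):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    torch_dtype=torch.bfloat16,  # bf16 instead of fp16
    attn_implementation="flash_attention_2",
    device_map="auto",
)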

Hi! I thought I would chime in here since I’m also trying to figure out how much flash_attention is improving inference speed. I modified Younes Belkada’s script here and got some results.

I found that FA improvements were only realized once the context length went past 4k and use_cache was set to True.
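
A minimal sketch of that comparison (not the modified script itself; the checkpoint and prompt below are placeholders I picked):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# A prompt well past 4k tokens; below that length I saw no measurable gain.
inputs = tokenizer(["All work and no play makes Jack a dull boy. " * 500],
                   return_tensors="pt").to(model.device)

def timed_generate(use_cache):
    # use_cache=True keeps the KV cache across decoding steps; this is
    # the setting where the FA2 speedup actually appeared.
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128, do_sample=False, use_cache=use_cache)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("use_cache=True :", timed_generate(True))
print("use_cache=False:", timed_generate(False))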