Hi, I was exploring the benefits of using FlashAttention-2 with Mistral and Mixtral during inference, yet I see no memory reduction and no speedup. Some numbers under different attention implementations:
Mixtral (mistralai/Mixtral-8x7B-Instruct-v0.1):
attn_implementation='flash_attention_2': 27.335 GB, 15.8 seconds
attn_implementation='eager': 27.335 GB, 16.1 seconds
attn_implementation='sdpa': 27.335 GB, 15.4 seconds
Mistral (mistralai/Mistral-7B-Instruct-v0.2):
attn_implementation='flash_attention_2': 6.407 GB, 31.5 seconds
attn_implementation='eager': 6.407 GB, 30.7 seconds
attn_implementation='sdpa': 6.405 GB, 28.9 seconds
(The Mistral runs took longer in total than the Mixtral runs because I tested on 20 examples, each with a smaller max_new_tokens.)
The code I used to test Mixtral is taken from the Hugging Face page:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to load the model onto
attn_implementation = "flash_attention_2"  # swapped to "eager" / "sdpa" for the other runs

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    torch_dtype=torch.float16, attn_implementation=attn_implementation,
    cache_dir="cache", load_in_4bit=True, device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "My favourite condiment is"
model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

%time generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)  # %time is an IPython magic
tokenizer.batch_decode(generated_ids)[0]
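The memory and time numbers above were read with a small helper roughly like the following (a rough sketch on my side, not part of the Hugging Face example; I rely on torch.cuda.reset_peak_memory_stats / max_memory_allocated for peak memory and time.perf_counter for wall time, and the prompt list here is just a placeholder):

import time
import torch

def measure_generation(model, tokenizer, prompts, max_new_tokens):
    # Reset CUDA peak-memory stats so max_memory_allocated reflects only this run.
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"peak memory: {peak_gb:.3f} GB, total time: {elapsed:.1f} s")

measure_generation(model, tokenizer, ["My favourite condiment is"], max_new_tokens=100)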
Library versions:
transformers==4.37.2
torch==2.2.0
flash-attn==2.5.3
My GPU is an NVIDIA A100-SXM4-40GB with CUDA release 12.3 (V12.3.107).
I know that the major benefit of FlashAttention-2 shows up during training, yet it is reported to be beneficial during inference as well: HuggingFace Page. Hence, I wonder whether there is a problem with my environment or with the current implementations.
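For completeness, this is roughly how I sanity-check that FlashAttention-2 is actually picked up, using the model loaded above (a minimal sketch; is_flash_attn_2_available comes from transformers.utils, and _attn_implementation is a private config attribute, so both may differ between versions):

import flash_attn
from transformers.utils import is_flash_attn_2_available

# Confirm that the flash-attn package is importable and visible to transformers.
print("flash-attn version:", flash_attn.__version__)
print("FA2 available:", is_flash_attn_2_available())

# After loading, check which attention implementation was actually selected.
# (_attn_implementation is a private attribute and may change between versions.)
print("active implementation:", model.config._attn_implementation)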
Thanks in advance!