SDPA attention in e.g. Llama does not use the fused kernels

Hi, I’m worried that the attention implementation relying on PyTorch’s scaled_dot_product_attention does not use its full potential.

Indeed, the function call at line 673 of modeling_llama.py does not pass the ‘is_causal’ argument, which is what allows the fused implementations to be selected (Accelerated PyTorch 2 Transformers | PyTorch):

" At present, the only attention mask supported by fused kernel implementation is the causal mask commonly used for training. To specify the causal mask in custom kernels, it must be specified with the is_causal boolean and attn_mask must be None".

Hope this can help,
Anthony.