FlashAttention or equivalent?

Hello - as always a huge thank you in advance to HuggingFace for creating such an amazing and open set of tools.

I am interested in using FlashAttention to achieve longer sequence lengths (and faster training times). Looking here and here, it looks like PyTorch 2.0 may have this built in via its own fused attention (`torch.nn.functional.scaled_dot_product_attention`). Does that flow through to HuggingFace’s transformers library? Is there a “simple” way to “flip on” FlashAttention in, say, an OPT model (passing `ignore_mismatched_sizes` during initialization to get a longer sequence length), while still using the HuggingFace Trainer with fp16 or bf16?
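To make the “flip on” part concrete, this is roughly what I have in mind (just a sketch of my intent; the checkpoint name and sequence length are placeholders, and I’m not sure the last step exists at all):

```python
import torch
from transformers import AutoModelForCausalLM

# Load OPT with a longer context; ignore_mismatched_sizes lets the
# (now larger, re-initialized) positional embedding matrix load anyway.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",            # placeholder checkpoint
    max_position_embeddings=4096,   # placeholder longer sequence length
    ignore_mismatched_sizes=True,
    torch_dtype=torch.float16,
)

# ...and then, ideally, some switch here that makes the attention layers
# use a FlashAttention-style kernel while training with the Trainer.
```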

Totally understand this is possible to do in “lower level” PyTorch. HuggingFace has made it so easy to stay at the “higher level” that I was wondering if it was possible without digging deeper and changing a bunch of low-level code : )
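For reference, this is the PyTorch 2.0 building block I mean (a minimal sketch; it assumes a CUDA GPU and fp16 so the fused kernel can actually be used):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) in fp16 on GPU
q = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 16, 2048, 64, device="cuda", dtype=torch.float16)

# PyTorch 2.0 fused attention: this single call can dispatch to a
# FlashAttention-style kernel on supported hardware.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```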

Thank you again.
