In the link above, they talk about batching with flash attention. However, they seem to say that we should pack all batch elements into one sequence rather than using the usual batching-and-padding approach.
I'm really quite lost; it would be really useful to see an example of how to implement this.
Bit of a late response here, but I believe I can provide an answer. I was curious why flex-attention (PyTorch's implementation of flash attention et al.) worked much better without batching, and I found the answer to your question in my search.
Anyways, what you're looking for is to create a block mask (a matrix that defines which queries can attend to which keys) over the flattened sequence, in which each token can only attend to tokens that are: a) prior (causal), and b) from the same original sequence (i.e. the same batch element before flattening). A sketch of this follows below.
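For concreteness, here is a minimal sketch of how such a mask could be built with `create_block_mask` and passed to `flex_attention`. It assumes a recent PyTorch with `torch.nn.attention.flex_attention` available; the sequence lengths, `doc_ids` layout, and shapes are my own illustrative choices, not taken from the linked example.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "cuda"

# Three sequences of lengths 4, 3, and 5 packed into one flat sequence of length 12.
# doc_ids[i] records which original sequence token i came from.
lengths = [4, 3, 5]
doc_ids = torch.cat(
    [torch.full((n,), i, dtype=torch.long) for i, n in enumerate(lengths)]
).to(device)
total_len = doc_ids.numel()  # 12

def causal_document_mask(b, h, q_idx, kv_idx):
    # a) causal: a query may only attend to itself and earlier positions
    causal = q_idx >= kv_idx
    # b) same sequence: queries never attend across packing boundaries
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]
    return causal & same_doc

# Build the block mask once; B=None / H=None broadcast over batch and heads.
block_mask = create_block_mask(
    causal_document_mask, B=None, H=None,
    Q_LEN=total_len, KV_LEN=total_len, device=device,
)

# Packed inputs: batch dimension of 1, all tokens flattened into one sequence.
n_heads, head_dim = 8, 64
q = torch.randn(1, n_heads, total_len, head_dim, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flex_attention(q, k, v, block_mask=block_mask)  # [1, n_heads, total_len, head_dim]
```

In practice you would wrap `flex_attention` in `torch.compile` to get the fused kernel, and recompute the block mask whenever the packing (and hence `doc_ids`) changes.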
FlexAttention calls this document masking. An implementation can be seen here: