Inconsistent output between flash attention and eager

If we train a model in bfloat16 with flash attention, the model works fine when I use flash attention at inference, but when I use eager mode the values are different. Does something like this happen, or might I be doing something wrong in my code?


Outputs may vary slightly due to differences in backends, attention mechanisms, or the runtime data types of model weights. However, unless it’s a bug, the results will generally be similar (within a few percent difference).


Short answer:
Yes, this absolutely can happen, especially in bfloat16. Different attention backends (eager vs FlashAttention2) are allowed to give slightly different numeric results. The key questions are:

  • How big are the differences?
  • Do they hurt loss/accuracy or change generated text in a major way?

If differences are small (e.g., 1e-3–1e-2 on logits), this is expected. If they are large or change behavior a lot, you might be hitting a masking or version bug.

Below is a structured explanation with background and concrete checks.


1. Background: what “eager” vs “flash_attention_2” actually are

In recent PyTorch + Hugging Face stacks, the “attention implementation” is pluggable:

  • attn_implementation="eager"
    Classic formulation using standard PyTorch ops (matmul, softmax, matmul). No special kernel.

  • attn_implementation="sdpa"
    Uses torch.nn.functional.scaled_dot_product_attention with one of several fused backends (math, efficient, flash, cuDNN).(PyTorch Docs)

  • attn_implementation="flash_attention_2"
    Uses the FlashAttention2 CUDA kernels from Dao et al. (Hugging Face integrates the external flash-attn library).(GitHub)

All three implement the same math in theory, but:

  • The operations are fused differently.
  • The order of additions/multiplications is different.
  • Some keep more intermediate values in low precision.

In infinite precision they’d match exactly; in floating point they don’t.
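
As a quick illustration of "same math, different kernel", here is a minimal sketch (shapes, dtype, and values are arbitrary assumptions) that computes attention both ways; a small nonzero difference is the expected outcome:

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy tensors with shape (batch, heads, seq_len, head_dim); purely illustrative.
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)

# "Eager"-style attention: explicit matmul -> softmax -> matmul.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
out_eager = torch.softmax(scores, dim=-1) @ v

# Fused SDPA: same math, but fused kernels and a different rounding order.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

print("max abs diff:", (out_eager - out_sdpa).abs().max().item())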


2. Why you see different outputs

2.1 Different kernels ⇒ different floating-point rounding

PyTorch’s SDPA and backends explicitly warn that different kernels can produce different results because they reorder operations and fuse things.(PyTorch Docs)

FlashAttention2 goes even further than SDPA:

  • It tiles Q/K/V in blocks.
  • It recomputes parts of the softmax in a numerically stable way.
  • It minimizes HBM I/O (GPU memory traffic).(GitHub)

These are all mathematically equivalent to standard attention, but:

  • Summations happen in a different order.
  • Rounding happens at different points.
  • That directly changes the final floating-point values.

So, even with the same weights, same input, same model, just changing attn_implementation is enough to change numbers.

2.2 bfloat16 makes differences bigger

You are using bfloat16. This matters a lot:

  • bfloat16 has an 8-bit exponent (like float32) but only 7 mantissa bits (float32 has 23).
  • Fewer mantissa bits ⇒ fewer significant digits ⇒ larger rounding error.
  • If you change the operation order (which all fused kernels do), the rounding changes.

FlashAttention2 requires fp16/bf16 and is specifically advertised for those.(Hugging Face)
In that regime, it is completely normal to see:

  • Larger numeric drift vs a float32 “reference”.
  • More sensitivity to reordering (eager vs flash).

So:

  • Eager attention (which often keeps intermediates such as the softmax in float32) and FlashAttention2 (which keeps more of the computation in low precision) will not numerically match in bfloat16.
  • That mismatch doesn’t automatically mean your code is wrong.
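
Here is a tiny, self-contained illustration of both effects (a generic floating-point sketch, nothing attention-specific): bfloat16 has much coarser resolution than float16/float32, and the same sum computed in a different order rounds differently once every intermediate is stored in bfloat16:

import torch

# 1 + 2**-8 is exactly representable in float32 and float16, but bfloat16 only
# has 7 mantissa bits (spacing 2**-7 near 1.0), so it rounds back to 1.0.
x = 1.0 + 2**-8
print(torch.tensor(x, dtype=torch.float32).item())   # 1.00390625
print(torch.tensor(x, dtype=torch.float16).item())   # 1.00390625
print(torch.tensor(x, dtype=torch.bfloat16).item())  # 1.0

# Addition is not associative once each intermediate is rounded to bfloat16.
a = torch.tensor(1.0, dtype=torch.bfloat16)
b = torch.tensor(-1.0, dtype=torch.bfloat16)
c = torch.tensor(2**-8, dtype=torch.bfloat16)
print(((a + c) + b).item())  # 0.0        (a + c rounds back to 1.0 first)
print((a + (c + b)).item())  # 0.00390625 (nothing is lost in this order)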

2.3 FlashAttention2 is not perfectly deterministic

There are multiple issues in the wild showing non-determinism or large discrepancies:

  • HF: “Transformer models are not deterministic when using FlashAttention2” – repeated runs with FA2 give slightly different logits even with determinism flags; eager and SDPA do not.(GitHub)
  • HF: issue on FA2 vs SDPA with attention masks giving inconsistent outputs.(GitHub)
  • HF: issues where using FA2 changes loss curves compared to SDPA or eager.(GitHub)
  • FlashAttention repo: explicit report where FA2 and PyTorch SDPA differ strongly for some configs.(GitHub)

FlashAttention2 has added “deterministic” options recently, but determinism is still a known pain point.(GitHub)

Conclusion: if you compare FA2 vs eager and ask “are they bit-identical?” the answer is almost always “no,” especially in bfloat16.
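
You can also check determinism of your own setup directly: run the exact same forward pass twice and compare. A minimal sketch (it assumes the model and inputs from the test script in section 4.1, with the model loaded using attn_implementation="flash_attention_2"):

import torch

# Assumes `model` (loaded with attn_implementation="flash_attention_2") and
# tokenized `inputs` already exist, as in the comparison script in section 4.1.
model.eval()
with torch.no_grad():
    run1 = model(**inputs).logits
    run2 = model(**inputs).logits

print("bit-identical across runs:", torch.equal(run1, run2))
print("max abs diff across runs :", (run1 - run2).abs().max().item())

If repeated FA2 runs already differ from each other, an eager-vs-FA2 mismatch of the same magnitude tells you nothing new.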


3. What’s “normal” vs “suspicious”?

You need to distinguish expected numerical noise from real problems.

3.1 “Normal” differences (usually OK)

For a single forward pass with:

  • same model weights
  • same input
  • model.eval() (no dropout)
  • only attn_implementation changed

If you measure on the logits:

  • max_abs_diff ≈ 1e-3–1e-2
  • mean_abs_diff much smaller than the typical logit magnitude
  • Argmax (top-1 predicted token) is the same or only rarely different
  • Loss/perplexity/accuracy change only in the 3rd–4th decimal place

That scale of difference is expected when:

  • Switching between eager / SDPA / FA2.(Medium)
  • Working in bfloat16/fp16 instead of float32.

In this case, your code is probably fine. You are just seeing floating-point behavior.
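
A small helper that prints these metrics (a sketch; compare_logits is an illustrative name, not an existing API):

import torch

def compare_logits(logits_a: torch.Tensor, logits_b: torch.Tensor) -> None:
    """Print the diff metrics used in the checklist above."""
    diff = (logits_a.float() - logits_b.float()).abs()
    print("max abs diff   :", diff.max().item())
    print("mean abs diff  :", diff.mean().item())
    # Fraction of positions where the top-1 predicted token agrees.
    agree = (logits_a.argmax(dim=-1) == logits_b.argmax(dim=-1)).float().mean()
    print("top-1 agreement:", agree.item())

# Usage, with the logits from the test in section 4.1:
# compare_logits(logits_flash, logits_eager)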

3.2 “Suspicious” differences (worth debugging or reporting)

Red flags:

  • max_abs_diff on logits ≳ 1.0, mean_abs_diff ≳ 0.1–1.0.
  • Generated text is dramatically different (not just minor variations).
  • Loss/accuracy curves diverge clearly between backends, not just tiny noise.
  • Only one backend (e.g., eager) produces reasonable outputs; FA2 outputs are consistently “bad,” or vice versa.(GitHub)

These patterns often indicate:

  • Attention mask being interpreted differently in each path.
  • A regression/bug in a specific model + Transformers + PyTorch + flash-attn version combo.(GitHub)

4. How to check your setup step-by-step

4.1 Minimal test to quantify the difference

Take your trained model and run a simple comparison:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # example; use your model
# Docs: https://huggingface.co/docs/transformers/en/main_classes/model  # noqa

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2",  # HF GPU perf docs  # noqa
)

model.eval()  # very important: turn off dropout

inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)

with torch.no_grad():
    logits_flash = model(**inputs).logits

# Switch attention backend in-place
model.set_attn_implementation("eager")  # see attention interface docs  # noqa

with torch.no_grad():
    logits_eager = model(**inputs).logits

diff = (logits_flash - logits_eager).abs()
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())

Key details:

  • model.eval() ensures no dropout, so differences come from kernels, not randomness.
  • Same model instance, just set_attn_implementation changes the backend.(Hugging Face)

Interpretation:

  • Small max/mean diff ⇒ expected.
  • Huge diff ⇒ go on to the checks below.
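
If your Transformers version does not have set_attn_implementation yet, a fallback sketch is to reload the model with the other backend and compare the logits the same way (slower and uses more memory, but needs no special API):

# Continues the script from 4.1 (model_id, tokenizer, inputs already defined).
model_eager = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="eager",
)
model_eager.eval()

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits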

4.2 Check masks and “causal” settings

Many real bugs come from masks, not from the core attention math:

  • Some backends expect a boolean mask (True/False).
  • Some expect an additive mask (0 or −inf).
  • FlashAttention2 has limitations around certain mask types and padding handling; some projects explicitly warn that FA2 doesn’t support general masks the same way SDPA does, and you must use “unpadded attention” patterns instead.(Hugging Face)

Questions to ask about your code:

  • Are you passing a custom attention_mask or key_padding_mask?
  • Is it 0/1, bool, or contains −inf?
  • Are you using a non-standard pattern (e.g., custom causal mask, sliding window, multi-modal masks)?

If yes, the mask may be interpreted differently by eager vs FA2. HF has at least one closed bug specifically about inconsistent outputs with FA2 and SDPA when an attention mask is used.(GitHub)
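
A quick sanity sketch for the mask format (it reuses the tokenizer from section 4.1; token counts are illustrative): Transformers expects a 0/1 padding mask and builds the causal part internally, so passing a hand-built additive (0/−inf) mask where a 0/1 mask is expected is a common source of backend-dependent results:

import torch

# Reuses `tokenizer` from the script in section 4.1.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common workaround for causal LMs

batch = tokenizer(["Hello world", "Hi"], return_tensors="pt", padding=True)

# What Transformers expects as `attention_mask`: 1 = real token, 0 = padding.
print(batch["attention_mask"])

# An additive (0 / large-negative) mask is a different convention; if your code
# builds one of these and passes it where a 0/1 mask is expected, eager and FA2
# may interpret it differently.
additive = (1.0 - batch["attention_mask"].float()) * torch.finfo(torch.float32).min
print(additive)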

4.3 Use the SDPA “math” backend as a reference

If you want a “golden” baseline:

  • PyTorch’s sdpa_kernel lets you force the math backend (less fused, more float32 intermediates).(PyTorch Docs)

Example idea:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# Docs: https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html  # noqa

with sdpa_kernel(backends=[SDPBackend.MATH]):
    model.set_attn_implementation("sdpa")
    with torch.no_grad():
        logits_math = model(**inputs).logits

# Now compare both backends against the math reference:
print("FA2 vs math  :", (logits_flash - logits_math).abs().max().item())
print("eager vs math:", (logits_eager - logits_math).abs().max().item())

If both eager and FA2 are reasonably close to logits_math but not identical to each other, that again points to normal kernel differences, not a serious bug.

4.4 Check versions and compatibility

You want a version stack where this is known to work reasonably:

  • Recent transformers (v4.40+ or whatever you’re on).
  • PyTorch version that your GPUs support well for SDPA and FA2.
  • A flash-attn version that matches your PyTorch and GPU.(GitHub)

There are GitHub issues where certain versions:

  • Make FA2 non-deterministic even when determinism flags are set.(GitHub)
  • Produce clearly wrong outputs for certain models with FA2.(GitHub)

If your versions match a known problematic combination, upgrading often fixes it.
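
To record (and report) the exact stack, something like this is enough (a sketch; the try/except is there because flash-attn may simply not be installed):

import torch
import transformers

print("torch       :", torch.__version__)
print("transformers:", transformers.__version__)
if torch.cuda.is_available():
    print("cuda        :", torch.version.cuda, "| gpu:", torch.cuda.get_device_name(0))

try:
    import flash_attn
    print("flash-attn  :", flash_attn.__version__)
except ImportError:
    print("flash-attn  : not installed")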


5. Practical recommendations for your use case

Given your description:

Trained in bfloat16 with FlashAttention; inference with flash is fine; inference with eager gives different values.

Reasonable interpretation:

  • Training and inference should generally use the same attention backend if you care about maximum consistency.
  • Using FA2 both for training and inference is common and perfectly fine, especially when performance matters.(Hugging Face)

So concretely:

  1. If differences are small (e.g., max_abs_diff ≲ 1e-2, predictions basically the same):

    • This is normal.
    • You can safely stick to FlashAttention2 for both training and inference.
  2. If differences are big and hurt quality (loss, accuracy, or generated outputs):

    • Check masks, eval() mode, version compatibility.
    • Compare with SDPA (math backend) as a reference.
    • If FA2 or eager clearly looks “wrong” vs math, you may have found a real bug; open a minimal repro issue on the relevant GitHub (Transformers or FlashAttention).
  3. For maximum stability:

    • Choose one backend (eager or SDPA or FA2) and keep it fixed for both training and deployment.
    • Document the exact stack (PyTorch, Transformers, flash-attn versions and flags).

6. Curated references and further reading

Hugging Face / Transformers

  • HF GPU performance & FlashAttention2 guide (when FA2 is used, dtype requirements, how to enable) – good high-level background on what FA2 is and how Transformers uses it.(Hugging Face)
  • Transformers model base docs (how attn_implementation is wired and how set_attn_implementation works).(Hugging Face)

PyTorch attention backends

  • scaled_dot_product_attention docs – explains how SDPA works and that different backends can yield different results.(PyTorch Docs)
  • sdpa_kernel + SDPBackend docs – how to force specific backends like MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION. Useful for debugging and for building a “reference” baseline.(PyTorch Docs)

FlashAttention / issues about discrepancies

  • FlashAttention GitHub repo – implementation details and algorithm background.(GitHub)
  • “Output Discrepancy Between FlashAttention and PyTorch Attention” (FlashAttention issue) – concrete example where differences are large and considered a bug.(GitHub)
  • HF issues on FA2 nondeterminism and inconsistent outputs with masks or specific models.(GitHub)

If you measure your actual max_abs_diff / mean_abs_diff and describe whether the generated text or metrics are significantly worse in eager mode, you can directly compare your situation against the “normal vs suspicious” criteria above.

Thanks @John6666! The difference is on the order of 1e-03, so I believe that’s expected. Thanks for the reply.


This topic was automatically closed 12 hours after the last reply. New replies are no longer allowed.