Different output when running inference with packing and flash attention in bf16

I am trying to pack the tokens and check whether the output stays the same with and without packing.

With fp16 precision, the results with and without packing were the same, but with bf16 I see a slight difference in the output.

Below is the code:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B", padding_side="left")
model = AutoModel.from_pretrained(
    "Qwen/Qwen3-Embedding-0.6B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).cuda()

# ignore the example text :)
texts = ["hi how are you this is what so that is that what is know is there any way to do this", " hi how can i help you today"]

# padded batch (left padding)
input_ids = tokenizer(
    texts,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
input_ids = {k: v.cuda() for k, v in input_ids.items()}


# packing the tokens: flatten input_ids and restart position_ids per sequence
input_ids_with_packing = tokenizer(
    texts,
    padding=False,
    truncation=False,
    max_length=512,
)
input_with_packing = {"input_ids": [], "position_ids": []}
for ids in input_ids_with_packing["input_ids"]:
    input_with_packing["input_ids"] += ids
    input_with_packing["position_ids"] += list(range(len(ids)))

input_with_packing = {k: torch.tensor([v]).cuda() for k, v in input_with_packing.items()}

So the final inputs look like this:

# without packing
{'input_ids': tensor([[  6023,   1246,    525,    498,    419,    374,   1128,    773,    429,
             374,    429,   1128,    374,   1414,    374,   1052,    894,   1616,
             311,    653,    419, 151643],
         [151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
          151643, 151643, 151643, 151643, 151643,  15588,   1246,    646,    600,
            1492,    498,   3351, 151643]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]],
        device='cuda:0')}



# with packing
{'input_ids': tensor([[  6023,   1246,    525,    498,    419,    374,   1128,    773,    429,
             374,    429,   1128,    374,   1414,    374,   1052,    894,   1616,
             311,    653,    419, 151643,  15588,   1246,    646,    600,   1492,
             498,   3351, 151643]], device='cuda:0'),
 'position_ids': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21,  0,  1,  2,  3,  4,  5,  6,  7]], device='cuda:0')}

When I run the model with these inputs:

with torch.no_grad():
    direct_outputs = model(**input_ids)
    packing_outputs = model(**input_with_packing)

The last_hidden_state outputs are:

# without packing
 tensor([[[ 2.4062e+00, -1.4609e+00, -1.8848e-01,  ..., -8.4375e+00,
           -1.1625e+01,  1.3504e-03],
          [-3.9375e+00, -3.2656e+00, -1.2500e+00,  ...,  1.5078e+00,
           -1.7090e-01,  1.0000e+00],
          [-2.4219e+00, -5.9688e+00, -1.4766e+00,  ...,  8.4375e-01,
           -1.0156e+00, -2.3281e+00],
          ...,
          [-1.6406e+00, -8.6875e+00, -1.3281e+00,  ..., -5.7812e-01,
           -1.4844e+00, -1.5625e-01],
          [-1.7344e+00, -6.6875e+00, -1.3594e+00,  ..., -3.1982e-02,
           -5.2344e-01, -1.7480e-01],
          [-9.8047e-01, -1.7188e+00, -9.6484e-01,  ..., -9.9121e-02,
            2.2344e+00,  4.6250e+00]],
 
         [[-4.3438e+00,  1.6797e-01,  2.5586e-01,  ...,  3.7969e+00,
            6.6250e+00, -3.9688e+00],
          [-4.3438e+00,  1.6797e-01,  2.5586e-01,  ...,  3.7969e+00,
            6.6250e+00, -3.9688e+00],
          [-4.3438e+00,  1.6797e-01,  2.5586e-01,  ...,  3.7969e+00,
            6.6250e+00, -3.9688e+00],
          ...,
          [ 5.7031e-01, -5.2188e+00, -1.1094e+00,  ...,  1.0000e+00,
           -4.0938e+00, -1.3984e+00],
          [-2.5000e+00, -3.6719e+00, -1.2812e+00,  ...,  1.1328e+00,
           -4.5508e-01, -2.4375e+00],
          [-1.6016e+00,  5.1953e-01, -1.1016e+00,  ...,  9.2188e-01,
            5.0000e-01,  1.7266e+00]]], device='cuda:0', dtype=torch.bfloat16))


# with packing
(tensor([[[  2.5312,  -1.4531,  -0.1982,  ...,  -8.4375, -11.7500,  -0.0121],
          [ -3.9688,  -3.1562,  -1.2656,  ...,   1.5703,  -0.1865,   0.8359],
          [ -2.3906,  -5.9062,  -1.4766,  ...,   0.9258,  -1.0234,  -2.2188],
          ...,
          [  0.5078,  -5.1250,  -1.1250,  ...,   1.0938,  -4.0625,  -1.3594],
          [ -2.5625,  -3.5938,  -1.2656,  ...,   1.1484,  -0.4316,  -2.4219],
          [ -1.7188,   0.5547,  -1.0938,  ...,   0.9609,   0.5977,   1.7031]]],
        device='cuda:0', dtype=torch.bfloat16),

Is there any reason for this, or am I doing something wrong here?


Or am I doing something wrong here?

Nothing looks wrong. It is common for differences in float precision or attention backends to slightly alter the output, though occasionally such differences do stem from actual bugs.


Yes. There is a reason, and in your case it is most likely expected behavior, not a serious mistake.

Main answer

Your two forwards are not the same low-level computation, even if they represent the same logical text inputs:

  • without packing: padded batch, shape roughly [2, 22], with an attention_mask
  • with packing: flattened batch, shape [1, 30], with reset position_ids and no attention_mask

For recent Hugging Face transformers, that packed form is the intended padding-free style for FlashAttention-based packing: Qwen3 forwards attention_mask and position_ids into create_causal_mask(...), and the packed-sequence path is triggered from position_ids when attention_mask is None. TRL’s padding-free docs and code describe the same pattern: flatten sequences and return position_ids instead of attention_mask. (GitHub)

So the packed input structure you built is conceptually valid for a modern stack.


Why BF16 can differ while FP16 looked the same

There are two separate effects here.

1) Different tensor layout means different accumulation order

PyTorch explicitly says that mathematically equivalent batched computations are not guaranteed to be bitwise identical to slice-by-slice or differently-shaped computations. It also says that applying an operation to a slice is not guaranteed to match slicing the result of the full operation. Your padded run and packed run have different shapes and different masking/layout behavior, so this warning applies directly. (PyTorch Documentation)
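As a quick illustration (not specific to attention), here is a hedged probe of that effect: the two matmuls below are mathematically equivalent but computed with different shapes, and PyTorch does not guarantee they are bitwise identical. Whether they actually differ depends on your hardware and kernel selection; the snippet only reports what happens on your setup.

import torch

torch.manual_seed(0)
x = torch.randn(2, 30, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

full = x.reshape(-1, 1024) @ w            # one [60, 1024] GEMM
split = torch.cat([x[0] @ w, x[1] @ w])   # two [30, 1024] GEMMs, same math

# Not guaranteed to match bitwise; this only probes your setup.
print(torch.equal(full, split), (full.float() - split.float()).abs().max().item())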

2) Fused attention backends can change the result slightly

PyTorch’s SDPA docs state that because floating-point operations are fused, the output may differ depending on which backend kernel is chosen. The docs also note that the math backend keeps intermediates in torch.float for torch.half and torch.bfloat16 inputs, which is one reason fused FA-style paths and higher-precision reference paths do not match exactly. (PyTorch Documentation)
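A small hedged check of that claim using plain SDPA (assumes PyTorch 2.3+ and a CUDA GPU with flash support): forcing the flash backend versus the math backend on the same inputs typically shows a small but nonzero difference in bf16.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)

with sdpa_kernel(SDPBackend.MATH):  # higher-precision reference path
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print((out_flash.float() - out_math.float()).abs().max().item())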

BF16 also has a PyTorch-documented reduced-precision GEMM reduction mode that is enabled by default and can be disabled if you see unwanted numerical effects. (PyTorch Documentation)

So “FP16 looked identical to me, BF16 shows a small drift” is plausible. It does not automatically mean the BF16 packed path is wrong. (PyTorch Documentation)


There is a very similar upstream report

A close public analogue exists in the FlashAttention repo: a user compared flash_attn_func and flash_attn_varlen_func in a BF16 setup and found small numerical differences when strict equality was expected. That is very close to your situation, because “normal path” versus “varlen / packed path” is essentially the same class of comparison. (GitHub)

There is also a Transformers issue reporting that flash_attention_2 is not deterministic in the same way as sdpa or eager, which further supports the idea that strict equality is the wrong expectation once FA2 is involved. (GitHub)


What you are doing right

Left padding

Qwen’s own model card recommends using padding_side="left" together with flash_attention_2 for this model family. (Hugging Face)

Packing via flattened input_ids + reset position_ids

That matches the HF packing design for models that use position_ids. The packing blog and TRL docs describe padding-free packing exactly in those terms. (Hugging Face)

Not passing attention_mask in the packed case

That is aligned with TRL’s padding-free implementation, which returns position_ids instead of attention_mask. (GitHub)


What is weak in the current comparison

1) You are comparing raw last_hidden_state

For an embedding model, that is the most sensitive thing you can compare. Qwen’s model card shows the intended inference path as:

  1. run the model,
  2. take last_hidden_state,
  3. do last-token pooling using the attention mask,
  4. normalize the embedding. (Hugging Face)

That means the more meaningful comparison is the final pooled normalized embedding, not raw hidden states at every token position.
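A minimal sketch of that pipeline for the left-padded batch above, condensed from the model card's mask-aware pooling by relying on left padding (so the last valid token of every row is simply the rightmost position); `model` and the padded `input_ids` dict are the ones from the original snippet:

import torch.nn.functional as F

with torch.no_grad():
    hidden = model(**input_ids).last_hidden_state      # [batch, seq_len, dim]

pooled = hidden[:, -1]                                  # last-token pooling under left padding
embeddings = F.normalize(pooled.float(), p=2, dim=1)    # L2-normalized embeddings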

2) You need to compare the correct slices

Your padded batch and packed batch are laid out differently.

For your example:

  • sequence 1 in the padded batch is direct_outputs.last_hidden_state[0, :len1]
  • sequence 2 in the padded batch is direct_outputs.last_hidden_state[1, -len2:] because of left padding
  • sequence 1 in the packed batch is packing_outputs.last_hidden_state[0, :len1]
  • sequence 2 in the packed batch is packing_outputs.last_hidden_state[0, len1:len1+len2]

If you compare whole tensors by eye, you are mixing “different positions in different layouts,” which makes the discrepancy look worse than it is. This is a comparison problem, not necessarily a model problem.


My judgment for your exact snippet

I would rank the explanations like this:

Most likely

Expected BF16 + FlashAttention2 numeric drift due to different shapes, different masking/layout, and fused-kernel accumulation order. (PyTorch Documentation)

Also likely

Your comparison method overstates the difference, because you are visually comparing raw last_hidden_state rather than aligned token slices and final pooled embeddings. (Hugging Face)

Less likely, but worth checking

Version-sensitive packing behavior. This code path has been under active change, and there are public issues around packed-mask handling and position_ids-based detection in transformers. (GitHub)


What I would change first

1) Put the model in eval mode

Do this before comparing anything:

model.eval()

transformers usually loads models in eval mode already, but calling it explicitly rules out dropout or other training-time randomness as a confounder.

2) Compare aligned token slices, not whole tensors

with torch.no_grad():
    direct = model(**input_ids).last_hidden_state
    packed = model(**input_with_packing).last_hidden_state

len1 = len(input_ids_with_packing["input_ids"][0])
len2 = len(input_ids_with_packing["input_ids"][1])

# sequence 1 fills its padded row completely, so the first len1 positions are valid
direct_1 = direct[0, :len1]
packed_1 = packed[0, :len1]

# sequence 2 is left-padded, so its valid tokens are the last len2 positions
direct_2 = direct[1, -len2:]
packed_2 = packed[0, len1:len1 + len2]

print("seq1 max abs diff:", (direct_1.float() - packed_1.float()).abs().max().item())
print("seq2 max abs diff:", (direct_2.float() - packed_2.float()).abs().max().item())

print("seq1 allclose:", torch.allclose(direct_1.float(), packed_1.float(), atol=1e-2, rtol=1e-2))
print("seq2 allclose:", torch.allclose(direct_2.float(), packed_2.float(), atol=1e-2, rtol=1e-2))

3) Compare the actual embeddings

For Qwen3-Embedding, compare the pooled embeddings, not just hidden states. The model card uses last-token pooling for left-padded inputs and then L2 normalization. (Hugging Face)

For your packed tensor, the pooled positions are:

  • seq1 pooled token: len1 - 1
  • seq2 pooled token: len1 + len2 - 1

Then normalize and compare cosine similarity.
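Continuing from the snippet in step 2 (so `direct`, `packed`, `len1`, and `len2` are assumed to already exist), a minimal sketch:

import torch.nn.functional as F

# Padded batch is left-padded, so the last valid token of each row is at position -1.
padded_emb = F.normalize(direct[:, -1].float(), p=2, dim=1)

# Packed batch: last tokens of the two segments sit at len1 - 1 and len1 + len2 - 1.
packed_last = torch.stack([packed[0, len1 - 1], packed[0, len1 + len2 - 1]])
packed_emb = F.normalize(packed_last.float(), p=2, dim=1)

print(F.cosine_similarity(padded_emb, packed_emb, dim=1))  # expect values very close to 1.0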

4) Build a stricter reference baseline

If you want to know whether the problem is semantic or only numeric, compare against a higher-precision non-fused reference:

  • torch_dtype=torch.float32
  • attn_implementation="eager" or a math-style reference path

PyTorch’s SDPA docs explicitly say the math backend is the higher-precision fallback and keeps float intermediates for half/BF16 inputs. (PyTorch Documentation)

If packed and non-packed match there but drift slightly only in BF16 + FA2, that strongly points to “expected numeric difference.”
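As a sketch, that reference configuration can be loaded like this (the full script at the end of this answer runs the same case via `compare_case`):

import torch
from transformers import AutoModel

# Stricter baseline: full precision, no fused attention kernels.
ref_model = AutoModel.from_pretrained(
    "Qwen/Qwen3-Embedding-0.6B",
    attn_implementation="eager",
    torch_dtype=torch.float32,
).cuda().eval()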

5) Try disabling BF16 reduced-precision reductions

PyTorch documents the first flag for BF16 GEMM reductions; the TF32 switches are included for good measure:

# Set these before running the model; they trade some throughput for tighter numerics.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

If the discrepancy shrinks, that is further evidence that your issue is numeric, not semantic. (PyTorch Documentation)


Practical conclusion

For your case, I would interpret the result like this:

  • No, your code is not obviously wrong just because BF16 packed output is slightly different.
  • Yes, there is a solid reason: different layout + packed masking path + FA2 fused kernels + BF16 reduction behavior. (GitHub)
  • The right standard is close within tolerance, especially on the final pooled normalized embeddings, not exact equality of raw hidden states. (Hugging Face)

The result would become concerning if:

  • the second packed segment is much worse than the first,
  • pooled embeddings differ materially,
  • or the mismatch stays large even in an FP32 / eager reference run.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
TEXTS = [
    "hi how are you this is what so that is that what is know is there any way to do this",
    "hi how can i help you today",
]
MAX_LENGTH = 512
DEVICE = "cuda"

# ----------------------------
# helpers
# ----------------------------
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Qwen3-Embedding uses left padding with FA2, so the last valid token of every
    # sequence sits at the rightmost position; fall back to length-based indexing otherwise.
    if attention_mask[:, -1].all():
        return last_hidden_states[:, -1]
    lengths = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
    return last_hidden_states[batch_idx, lengths]

def build_packed_inputs(tokenizer, texts, max_length, device):
    # Important: use the SAME truncation settings as the padded path.
    unpadded = tokenizer(
        texts,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_attention_mask=False,
    )

    flat_ids = []
    flat_pos = []
    lengths = []

    for ids in unpadded["input_ids"]:
        flat_ids.extend(ids)
        flat_pos.extend(range(len(ids)))
        lengths.append(len(ids))

    packed = {
        "input_ids": torch.tensor([flat_ids], device=device),
        "position_ids": torch.tensor([flat_pos], device=device),
    }
    return packed, lengths

def compare_case(model_name, texts, dtype, attn_impl, max_length=512):
    print(f"\n=== dtype={dtype}, attn_impl={attn_impl} ===")

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    model = AutoModel.from_pretrained(
        model_name,
        attn_implementation=attn_impl,
        torch_dtype=dtype,
    ).to(DEVICE)
    model.eval()

    # Padded inputs
    padded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    padded = {k: v.to(DEVICE) for k, v in padded.items()}

    # Packed inputs
    packed, packed_lengths = build_packed_inputs(tokenizer, texts, max_length, DEVICE)

    with torch.inference_mode():
        out_padded = model(**padded).last_hidden_state
        out_packed = model(**packed).last_hidden_state

    # Compare per-sequence hidden states
    lengths = padded["attention_mask"].sum(dim=1).tolist()
    assert lengths == packed_lengths, "padded and packed tokenization disagree"
    start = 0
    pooled_packed = []

    print("\nPer-sequence hidden-state comparison")
    for i, L in enumerate(lengths):
        # left padding => valid tokens are the last L positions
        hs_padded = out_padded[i, -L:]
        hs_packed = out_packed[0, start:start + L]
        start += L

        diff = (hs_padded.float() - hs_packed.float()).abs()
        max_abs = diff.max().item()
        mean_abs = diff.mean().item()
        ok = torch.allclose(hs_padded.float(), hs_packed.float(), atol=1e-2, rtol=1e-2)

        print(f"seq{i}: len={L:3d} | max_abs={max_abs:.6f} | mean_abs={mean_abs:.6f} | allclose={ok}")

        pooled_packed.append(hs_packed[-1])

    # Compare final pooled normalized embeddings
    pooled_padded = last_token_pool(out_padded, padded["attention_mask"])
    pooled_packed = torch.stack(pooled_packed, dim=0)

    emb_padded = F.normalize(pooled_padded.float(), p=2, dim=1)
    emb_packed = F.normalize(pooled_packed.float(), p=2, dim=1)

    cos = F.cosine_similarity(emb_padded, emb_packed, dim=1)
    emb_diff = (emb_padded - emb_packed).abs()

    print("\nFinal embedding comparison")
    for i in range(len(texts)):
        print(
            f"seq{i}: cosine={cos[i].item():.8f} | "
            f"max_abs={emb_diff[i].max().item():.8f} | "
            f"mean_abs={emb_diff[i].mean().item():.8f}"
        )

if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise RuntimeError("This script requires CUDA.")

    # Main case: the one you care about
    compare_case(
        model_name=MODEL_NAME,
        texts=TEXTS,
        dtype=torch.bfloat16,
        attn_impl="flash_attention_2",
        max_length=MAX_LENGTH,
    )

    # Reference case: useful to see whether differences are only from BF16 + FA2
    compare_case(
        model_name=MODEL_NAME,
        texts=TEXTS,
        dtype=torch.float32,
        attn_impl="eager",
        max_length=MAX_LENGTH,
    )

What this script checks

It tests the two comparisons that matter:

  1. Per-sequence hidden states

    • padded path vs packed path
    • aligned correctly for left padding
  2. Final pooled embeddings

    • takes the last valid token for each sequence
    • L2-normalizes
    • reports cosine similarity and absolute differences

How to read the output

For the BF16 + FlashAttention2 block, the usual healthy pattern is:

  • max_abs on hidden states is not zero
  • allclose=True often holds with atol=1e-2, rtol=1e-2
  • pooled embedding cosine is very close to 1.0

For the FP32 + eager block, differences should usually be smaller. If FP32/eager matches closely but BF16/FA2 shows only slight drift, that points to expected numeric behavior rather than a packing bug.

Two small but important fixes compared with your original snippet

  • It uses the same truncation settings in both paths.
  • It compares the correct slices for each sequence instead of eyeballing full tensors.