Why do the values of the logits change depending on whether samples are batched or not?

I’m noticing for models like BERT that the values of the model’s output logits change depending on how the samples are batched and padded.

An example:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

a = "hey this is a thing"
b = "this is another thing but its wayyy longer"

# arbitrarily pad one token past the longest sequence; that extra padding token
# should get ignored in self-attention thanks to the `attention_mask`
inputs = tokenizer([a, b], padding="max_length", max_length=len(tokenizer(b)['input_ids']) + 1, return_tensors='pt')
logits = model(**inputs)['logits']  # torch.Size([2, 13, 30522])


# now get the logits for each sequence individually
logits1 = model(**tokenizer(a, return_tensors='pt'))['logits']  # torch.Size([1, 7, 30522])
logits2 = model(**tokenizer(b, return_tensors='pt'))['logits']  # torch.Size([1, 12, 30522])

# check if they agree (mind the slicing :p)
print(torch.allclose(logits[0][:logits1.shape[1]].unsqueeze(0), logits1))
print(torch.allclose(logits[1][:logits2.shape[1]].unsqueeze(0), logits2))

>> False
>> False 
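Out of curiosity you can also check how large the disagreement actually is. This bit isn’t part of my original script, just a suggested diagnostic reusing the tensors from above:

# diagnostic: magnitude of the disagreement (reuses `logits`, `logits1`, `logits2` from above)
print((logits[0][:logits1.shape[1]].unsqueeze(0) - logits1).abs().max())
print((logits[1][:logits2.shape[1]].unsqueeze(0) - logits2).abs().max())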

Now if you just pad to the longest sequence and rerun everything, you get:

# ....
inputs = tokenizer([a, b], padding=True, return_tensors='pt')

# same as above

print(torch.allclose(logits[0][:logits1.shape[1]].unsqueeze(0), logits1))
print(torch.allclose(logits[1][:logits2.shape[1]].unsqueeze(0), logits2))

>> False
>> True 
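If the remaining mismatch is just floating-point noise, it should disappear under a looser tolerance than torch.allclose’s defaults (rtol=1e-5, atol=1e-8). I haven’t pinned down the exact threshold, so treat this as a sanity check rather than a guarantee:

# sanity check: does the mismatch vanish with a looser tolerance?
print(torch.allclose(logits[0][:logits1.shape[1]].unsqueeze(0), logits1, rtol=1e-4, atol=1e-4))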

Does anyone know why changing how samples are padded or batched affects the actual model outputs, when intuitively the `attention_mask` is supposed to make self-attention ignore the padding tokens?
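To illustrate what I mean, here’s a toy sketch (not DistilBERT’s actual attention code) showing that, in isolation, a masked-out key position shouldn’t change the attention outputs for the real positions:

import torch

torch.manual_seed(0)
d = 16
q = torch.randn(5, d)                      # 5 "real" query positions
k = torch.randn(5, d)
v = torch.randn(5, d)

# same keys/values plus one extra "padding" position
k_pad = torch.cat([k, torch.randn(1, d)])
v_pad = torch.cat([v, torch.randn(1, d)])

def attend(q, k, v, mask=None):
    scores = q @ k.T / d**0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))   # padding keys get -inf score
    return torch.softmax(scores, dim=-1) @ v

mask = torch.ones(5, 6, dtype=torch.bool)
mask[:, -1] = False                        # no query may attend to the padding key

out = attend(q, k, v)
out_pad = attend(q, k_pad, v_pad, mask)

print(torch.allclose(out, out_pad, atol=1e-6))  # expect True: the masked key gets weight 0

So, at least in exact arithmetic, the padding positions shouldn’t leak into the real ones, which is why the False results above surprise me.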

Have you found an answer so far? @nnethercott

I’m looking into a similar situation with the model "meta-llama/Meta-Llama-3-8B-Instruct", where the logits change for the same prompt when it’s batched. Below is pseudocode where I would expect the same logits to come out (just stacked), but surprisingly they do not.

logits1 = model(**tokenizer([a]), ...)
logits2 = model(**tokenizer([a, a]), ...)
logits3 = model(**tokenizer([a, a, a]), ...)

I’ve created a full reproducible example here: An example with Llama3, where batching the input prompts modifies the logits that come out. · GitHub
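For context, this is roughly the stripped-down shape of the comparison in the gist. The prompt string here is a placeholder I’m adding for illustration, and it assumes you have access to the gated checkpoint and enough GPU memory to hold it in bfloat16:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"   # gated model: requires access approval
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token          # Llama 3 ships without a pad token
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

a = "Explain why the sky is blue."                 # placeholder prompt, not the one from the gist

with torch.no_grad():
    one   = model(**tokenizer([a], return_tensors="pt", padding=True).to(model.device)).logits
    three = model(**tokenizer([a, a, a], return_tensors="pt", padding=True).to(model.device)).logits

# in exact arithmetic every row of `three` would equal `one[0]`
print(torch.allclose(one[0], three[0]))
print((one[0].float() - three[0].float()).abs().max())

Since the three prompts are identical there is no padding at all in the batched call, so any difference comes purely from batching.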

EDIT: I found that the differences in logits disappear when I increase the floating-point precision from bfloat16 to float32. However, this also meant offloading part of the model to the CPU.
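For reference, this is roughly how I switched precision (reusing model_id from the sketch above); with device_map="auto", accelerate offloads layers to the CPU when the float32 weights don’t fit on the GPU:

# same comparison as before, but in float32: roughly twice the memory of bfloat16,
# which is why part of the model ends up offloaded to the CPU on my setup
model_fp32 = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="auto"
).eval()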