The inconsistency seems to occur with any dtype other than torch.float32, and it is particularly noticeable with torch.bfloat16. Some reports also suggest the problem is specific to Qwen 2.5. With bfloat16, the attention implementation may also be a contributing factor.
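One way to probe that suspicion is to load the model twice under bfloat16 with different attention backends and compare the logits for the same input. This is a minimal sketch, assuming a transformers version that accepts the attn_implementation argument; the choice of backends and the printed metric are illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(["a"], return_tensors="pt")

logits = {}
for impl in ("eager", "sdpa"):  # "flash_attention_2" could be added if it is installed
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=impl,  # select the attention backend
    ).eval()
    with torch.no_grad():
        logits[impl] = model(**inputs.to(model.device)).logits[0].float()
    del model

# compare the two backends' outputs under bfloat16
print("max |eager - sdpa| diff: ", (logits["eager"] - logits["sdpa"]).abs().max().item())
The full reproduction script follows; as noted in its comments, casting the model to float32 restores agreement between the batched and single-example outputs.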
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# load model and tokenizer
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval().to(torch.float32)  # cast to float32; keeping bfloat16 reproduces the inconsistency
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("model.dtype: ", model.dtype)
print("model.device: ", model.device)
# input texts
texts = ['a', 'b', 'c']
# tokenize
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(model.device)
# get inputs_embeds
with torch.no_grad():
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
# get attention_mask and position_ids
attention_mask = inputs.attention_mask
position_ids = torch.arange(inputs.input_ids.shape[1], device=model.device).unsqueeze(0).expand(inputs.input_ids.shape[0], -1)
# batch
with torch.no_grad():
    output_batch = model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
    ).logits[0]
# single
with torch.no_grad():
    output_single = model(
        inputs_embeds=inputs_embeds[0].unsqueeze(0),
        attention_mask=attention_mask[0].unsqueeze(0),
        position_ids=position_ids[0].unsqueeze(0),
    ).logits[0]
# check consistency
is_close = torch.allclose(output_batch, output_single, atol=1e-5, rtol=1e-3)
print("consistent?: ", is_close)
print("batch: ", output_batch)
print("single: ", output_single)