Why are the logits different when I use the cache (past_key_values) in the transformers GPT-2 model compared to a forward pass from scratch?

I'm trying to use past_key_values to speed up inference:

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Full forward pass over all five tokens
seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits

# Cached pass: encode the first three tokens and keep their keys/values...
seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values
# ...then feed only the two new tokens together with the cache
new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits

# Compare the logits at the last position
print(torch.equal(original_out[-1, :], magic[-1, :]))
```

But this prints False, while I expected True.

Hey, the only issue here is that you used torch.equal.
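Floating-point arithmetic rounds at every step, so the result depends on the order in which operations are performed. The cached pass computes the same attention as the full pass, but with differently shaped tensors (2 query positions instead of 5), so the underlying kernels can accumulate sums in a different order and the logits can differ in the last few bits. This is expected, normal, and usually totally fine.
However, it does mean that strict equality checks like torch.equal will often fail when you would want them to pass.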
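A tiny illustration of the underlying effect, independent of GPT-2: floating-point addition is not associative, so summing the same three numbers in a different order can change the last bit. This is only an analogy for what happens inside the cached pass; the values 0.1, 0.2 and 0.3 are arbitrary.

```python
import torch

# Same three numbers as float64 tensors (same IEEE doubles as Python floats)
x, y, z = (torch.tensor(v, dtype=torch.float64) for v in (0.1, 0.2, 0.3))

left = (x + y) + z   # 0.6000000000000001
right = x + (y + z)  # 0.6

print(torch.equal(left, right))     # False: bit-for-bit comparison
print(torch.allclose(left, right))  # True: equal within the default tolerances
print((left - right).abs().item())  # a difference of about 1e-16
```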
That is why torch.allclose exists: it checks element-wise that |input - other| <= atol + rtol * |other| (by default rtol=1e-05 and atol=1e-08). In some cases you may need to loosen atol and rtol, but here the defaults work fine.

Trying your code with torch.allclose, it returns True :slight_smile:
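A minimal sketch of the same comparison done with a tolerance, assuming a recent transformers version. To stay self-contained it builds a tiny, randomly initialised model from GPT2Config (the sizes are hypothetical, chosen only to make it fast) instead of downloading "gpt2", and it adds an explicit batch dimension. Random-init logits are small, so many sit near zero where the relative check alone is very strict; the atol=1e-5 and rtol=1e-4 used here are assumptions, which is exactly the "loosen atol and rtol" case mentioned above.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Hypothetical tiny GPT-2 so the sketch runs without downloading weights
torch.manual_seed(0)
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=64, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config)
model.eval()  # disable dropout; from_pretrained() does this, a fresh model does not

full_ids = torch.tensor([[1, 2, 3, 4, 5]])  # shape (batch=1, seq_len=5)
prefix_ids = full_ids[:, :3]                # tokens 1, 2, 3
new_ids = full_ids[:, 3:]                   # tokens 4, 5

with torch.no_grad():
    # Full forward pass over all five tokens
    full_logits = model(input_ids=full_ids).logits
    # Cached path: encode the prefix once, then feed only the new tokens
    past = model(input_ids=prefix_ids, use_cache=True).past_key_values
    cached_logits = model(input_ids=new_ids, past_key_values=past).logits

# Compare the logits for the two new positions (3 and 4)
expected = full_logits[:, 3:]
max_diff = (expected - cached_logits).abs().max().item()
exactly_equal = torch.equal(expected, cached_logits)  # may be True or False
close = torch.allclose(cached_logits, expected, rtol=1e-4, atol=1e-5)
print(f"max |diff| = {max_diff:.2e}, equal: {exactly_equal}, allclose: {close}")
```

Printing the maximum absolute difference alongside the boolean is a cheap way to see how far apart the two paths really are before deciding on a tolerance.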