I'm trying to use `past_key_values` to speed up inference:
```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

torch.set_default_device("cuda")  # new tensors are created on the GPU by default

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
model.to("cuda")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Full pass over all five tokens
seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits

# Pass over the first three tokens, keeping the key/value cache
seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values

# Feed only the remaining two tokens on top of the cache
new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits

# Compare the logits for the last position
print(torch.equal(original_out[-1, :], magic[-1, :]))
```
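But this prints `False`. I expected it to print `True`, since feeding `[4, 5]` on top of the cache for `[1, 2, 3]` should give the same last-token logits as feeding `[1, 2, 3, 4, 5]` in a single pass.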
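For comparison, here is a minimal sketch that checks the two paths with a tolerance (`torch.allclose`) instead of bitwise equality (`torch.equal`). It assumes the mismatch comes from floating-point differences: the cached pass multiplies differently shaped tensors than the full pass, so the order of floating-point operations can change and the last bits of the results may differ. To run offline, it builds a small randomly initialized GPT-2 from a `GPT2Config` with arbitrarily chosen sizes instead of loading the pretrained `gpt2` weights, and it adds an explicit batch dimension to the inputs.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Small, randomly initialized GPT-2; sizes are arbitrary, for illustration only
config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).to(device).eval()

full_ids = torch.tensor([[1, 2, 3, 4, 5]], device=device)  # shape (batch=1, seq=5)

with torch.no_grad():
    # Full forward pass over all five tokens
    full_logits = model(input_ids=full_ids).logits

    # Prefix pass over the first three tokens, returning the key/value cache
    prefix_out = model(input_ids=full_ids[:, :3], use_cache=True)

    # Continue from the cache with the remaining two tokens
    cached_logits = model(
        input_ids=full_ids[:, 3:],
        past_key_values=prefix_out.past_key_values,
    ).logits

last_full = full_logits[0, -1]
last_cached = cached_logits[0, -1]

print("bitwise equal:", torch.equal(last_full, last_cached))
print("max abs diff:", (last_full - last_cached).abs().max().item())
print("allclose:", torch.allclose(last_full, last_cached, atol=1e-5, rtol=1e-4))
```

To try this with the pretrained weights from the question, replace the config-built model with `GPT2LMHeadModel.from_pretrained("gpt2")`. The tolerances above are a guess for float32 and would likely need loosening under half precision.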