GPT-2 Forward w/ and w/o caching of past gives different results

Hey, I am running a forward() pass of the GPT-2 model with and without caching of past key/values and have observed that the resulting logits are slightly different. Is this to be expected? I would highly appreciate it if someone could help me with this (tagging @patrickvonplaten for his expertise).

Code snippet for an MWE is below (check the last assert statement, which is the one that fails).

from transformers import GPT2LMHeadModel
import torch

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

with torch.no_grad():
    #########################################################################
    # with forward and no caching of past
    # left padded to size of 5
    # step 0
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2]).reshape(1, -1)
    attention_mask = torch.tensor([0, 0, 0, 0, 1]).reshape(1, -1)
    # position_ids built the way prepare_inputs_for_generation builds them for
    # left padding: cumsum(attention_mask) - 1, with padded positions set to 1
    position_ids = torch.tensor([1, 1, 1, 1, 0]).reshape(1, -1)

    gen_outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        return_dict=True)
    no_cache_0_next_token_logits = gen_outputs.logits[0, -1, :].clone()

    # step 1 - input grown by 1
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2, 5]).reshape(1, -1)
    attention_mask = torch.tensor([0, 0, 0, 0, 1, 1]).reshape(1, -1)
    # same scheme: padded positions stay at 1, real tokens get positions 0, 1
    position_ids = torch.tensor([1, 1, 1, 1, 0, 1]).reshape(1, -1)
    gen_outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        return_dict=True)
    no_cache_1_next_token_logits = gen_outputs.logits[0, -1, :].clone()

    ########################################################################
    # with forward with caching
    # left padded to size of 5
    # step 0
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2]).reshape(1, -1)
    model_kwargs = {
        "attention_mask": torch.tensor([0, 0, 0, 0, 1]).reshape(1, -1)
    }

    model_inputs = model.prepare_inputs_for_generation(
        input_ids, **model_kwargs)
    gen_outputs = model(**model_inputs,
                        return_dict=True)
    cache_0_next_token_logits = gen_outputs.logits[0, -1, :].clone()
    assert torch.equal(cache_0_next_token_logits,
                       no_cache_0_next_token_logits)
    model_kwargs = model._update_model_kwargs_for_generation(
        gen_outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
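    # optional sanity check: the cache from step 0 should now be threaded through
    # model_kwargs (the key is "past" or "past_key_values" depending on the
    # transformers version)
    assert "past" in model_kwargs or "past_key_values" in model_kwargs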

    # step 1 - input grown by 1
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2, 5]).reshape(1, -1)
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, **model_kwargs)
    gen_outputs = model(**model_inputs,
                        return_dict=True)
    cache_1_next_token_logits = gen_outputs.logits[0, -1, :].clone()
    assert torch.equal(cache_1_next_token_logits,
                       no_cache_1_next_token_logits)
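
For reference, the mismatch can be quantified by swapping the failing assert for an explicit comparison along these lines (a sketch reusing the tensors defined in the snippet above):

    # how far apart are the two step-1 next-token logit tensors?
    diff = (cache_1_next_token_logits - no_cache_1_next_token_logits).abs().max()
    print("max abs difference:", diff.item())
    # torch.equal requires bit-exact equality; torch.allclose compares within a tolerance
    print("allclose:", torch.allclose(cache_1_next_token_logits,
                                      no_cache_1_next_token_logits, atol=1e-4))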