Hey, I am trying to do a forward() pass of the GPT-2 model with and without caching of past key/values, and I observed that the logits are slightly different. Is this to be expected? I would highly appreciate it if someone could help me with this (tagging @patrickvonplaten for his expertise).
Code snippet for an MWE below (check the last assert statement, which fails):
from transformers import GPT2LMHeadModel
import torch
model = GPT2LMHeadModel.from_pretrained('gpt2')
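# eval() disables dropout so repeated forward passes are deterministic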
model.eval()
with torch.no_grad():
    #########################################################################
    # forward pass without caching of past key/values
    # input left-padded to length 5
    # step 0
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2]).reshape(1, -1)
    attention_mask = torch.tensor([0, 0, 0, 0, 1]).reshape(1, -1)
    # position ids built by hand: padded positions filled with 1, the real
    # token at position 0 (this mirrors what prepare_inputs_for_generation
    # derives from the attention mask)
    position_ids = torch.tensor([1, 1, 1, 1, 0]).reshape(1, -1)
    gen_outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        return_dict=True)
    no_cache_0_next_token_logits = gen_outputs.logits[0, -1, :].clone()
    # step 1 - input grown by one token
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2, 5]).reshape(1, -1)
    attention_mask = torch.tensor([0, 0, 0, 0, 1, 1]).reshape(1, -1)
    position_ids = torch.tensor([1, 1, 1, 1, 0, 1]).reshape(1, -1)
    gen_outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        return_dict=True)
    no_cache_1_next_token_logits = gen_outputs.logits[0, -1, :].clone()
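    # sanity check (my own addition, not part of the original MWE): the
    # hand-built position ids should match what GPT-2's
    # prepare_inputs_for_generation derives from the attention mask,
    # i.e. cumsum(-1) - 1 with padded positions filled with 1
    derived_position_ids = attention_mask.long().cumsum(-1) - 1
    derived_position_ids.masked_fill_(attention_mask == 0, 1)
    assert torch.equal(derived_position_ids, position_ids)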
    ########################################################################
    # forward pass with caching of past key/values
    # input left-padded to length 5
    # step 0
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2]).reshape(1, -1)
    model_kwargs = {
        "attention_mask": torch.tensor([0, 0, 0, 0, 1]).reshape(1, -1)
    }
    # let prepare_inputs_for_generation build position_ids (and, at later
    # steps, pick up the cached past)
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, **model_kwargs)
    gen_outputs = model(**model_inputs,
                        return_dict=True)
    cache_0_next_token_logits = gen_outputs.logits[0, -1, :].clone()
    # no past exists at step 0, so both paths run the identical computation
    # and this assert passes
    assert torch.equal(cache_0_next_token_logits,
                       no_cache_0_next_token_logits)
    model_kwargs = model._update_model_kwargs_for_generation(
        gen_outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
    )
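    # model_kwargs now carries the past key/values from step 0, and the
    # attention mask has been extended with a 1 for the newly added token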
    # step 1 - input grown by one token; with a past in model_kwargs,
    # prepare_inputs_for_generation feeds only the last token to the model
    input_ids = torch.tensor([50256, 50256, 50256, 50256, 2, 5]).reshape(1, -1)
    model_inputs = model.prepare_inputs_for_generation(
        input_ids, **model_kwargs)
    gen_outputs = model(**model_inputs,
                        return_dict=True)
    cache_1_next_token_logits = gen_outputs.logits[0, -1, :].clone()
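    # diagnostic (my own addition): quantify the mismatch before the exact
    # equality check below, to see whether the gap is within the usual
    # floating point tolerance or something larger
    abs_diff = (cache_1_next_token_logits - no_cache_1_next_token_logits).abs()
    print("max abs diff:", abs_diff.max().item())
    print("allclose (default tolerances):",
          torch.allclose(cache_1_next_token_logits, no_cache_1_next_token_logits))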
    # this is the assert that fails: the cached and cache-free step-1 logits
    # are slightly different
    assert torch.equal(cache_1_next_token_logits,
                       no_cache_1_next_token_logits)