Does the model support partial `past_key_values`?

Demo code:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

# Initial input ids
input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.int64)

# One forward pass to build the KV cache
output = model(input_ids)
past_key_values = output.past_key_values

# Append the greedily predicted next token
next_token = output.logits[:, -1].argmax(dim=1, keepdim=True)
input_ids = torch.cat((input_ids, next_token), dim=1)

# Append some custom tokens that are not covered by the cache
additional_tokens = torch.zeros(1, 3, dtype=torch.int64)
input_ids = torch.cat((input_ids, additional_tokens), dim=1)

# Run the model with the now-partial cache
model(input_ids, past_key_values=past_key_values)  # Works
model.generate(input_ids, max_length=30, past_key_values=past_key_values)  # Fails
```
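
To be precise, "Works" above only means the call does not raise. Here is the kind of check I have in mind for spotting a silent bug (a minimal sketch, reusing the variables from the demo; it compares the partial-cache call against a plain forward pass over the full sequence with no cache):

```python
with torch.no_grad():
    full_out = model(input_ids)  # full sequence, no cache
    cached_out = model(input_ids, past_key_values=past_key_values)  # partial cache

# If the partial cache were handled correctly, the logits at the last
# position should agree up to numerical noise.
print(torch.allclose(full_out.logits[:, -1], cached_out.logits[:, -1], atol=1e-4))
```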

Several questions:

  • Even though `model(…)` does not raise an exception, does it actually compute the right thing, or is there a silent bug? (The consistency check sketched after the demo is what I have in mind.)
  • If the model does support partial `past_key_values`, why doesn't `model.generate`? (The manual decoding loop I currently fall back to is sketched below.)
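
For context, this is roughly the manual loop I use instead of `model.generate` at the moment (a rough sketch, greedy decoding only, continuing from the demo above; it feeds the model only the tokens the cache has not seen yet and then extends the cache one token at a time). It assumes the legacy tuple cache layout, where each layer's key tensor has the sequence length on the second-to-last dimension:

```python
past = past_key_values
past_length = past[0][0].shape[-2]  # number of tokens already covered by the cache
generated = input_ids

with torch.no_grad():
    for _ in range(30 - generated.shape[1]):  # roughly max_length=30
        # Feed only the tokens the cache has not seen yet
        out = model(generated[:, past_length:], past_key_values=past, use_cache=True)
        past = out.past_key_values
        past_length = generated.shape[1]
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated = torch.cat((generated, next_token), dim=1)

print(generated)
```

This is essentially what I would expect `model.generate` to do when the cache covers only a prefix of `input_ids`.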

Thanks!

transformers: v4.29
Python: v3.10.10