Demo code:
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

# Initial input ids
input_ids = torch.tensor([[1, 2, 3, 4]], dtype=torch.int64)

# One forward pass; use_cache defaults to True, so the cache is returned
output = model(input_ids)
past_key_values = output.past_key_values  # covers the first 4 tokens

# Append the greedily predicted next token
next_token = output.logits[:, -1].argmax(dim=1, keepdim=True)
input_ids = torch.cat((input_ids, next_token), dim=1)

# Append some custom tokens that are not covered by the cache
additional_tokens = torch.zeros(1, 3, dtype=torch.int64)
input_ids = torch.cat((input_ids, additional_tokens), dim=1)

# Generate: the cache now covers only a prefix of input_ids
model(input_ids, past_key_values=past_key_values)  # works (no exception)
model.generate(input_ids, max_length=30, past_key_values=past_key_values)  # fails
```
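For context, this is what I mean by a "partial" cache: after the snippet above, `past_key_values` covers only the first 4 of the 8 tokens in `input_ids`. A quick inspection (assuming I read GPT-Neo's `(batch, heads, seq, head_dim)` layout correctly):

```python
# past_key_values is a tuple with one (key, value) pair per layer;
# for gpt-neo-125M I expect 12 layers, and a key/value layout of
# (batch, num_heads, seq_len, head_dim).
print(len(past_key_values))         # 12
print(past_key_values[0][0].shape)  # torch.Size([1, 12, 4, 64]) -> cache covers 4 tokens
print(input_ids.shape)              # torch.Size([1, 8])         -> but input has 8 tokens
```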
Several questions:
- Even if `model(…)` does not raise an exception, does it actually work, or is there a silent bug? (See the `allclose` check sketched after this list.)
- If the model does support a partial `past_key_values`, why doesn't `model.generate`? (For reference, the documented suffix-only pattern is sketched below.)
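To make the first question concrete, this is the check I would run, assuming the two calls should agree if a partial cache is handled correctly:

```python
# If the partial cache is handled correctly, the logits at the last
# position should match a plain full-context forward pass.
with torch.no_grad():
    full = model(input_ids).logits[:, -1]
    cached = model(input_ids, past_key_values=past_key_values).logits[:, -1]
print(torch.allclose(full, cached, atol=1e-4))  # False would indicate a silent bug
```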
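And for the second question, the usage I know is documented is to pass only the tokens the cache does not yet cover. A minimal sketch of that pattern, assuming the cached sequence length lives in dim `-2` of the key tensors as in GPT-Neo's modeling code:

```python
# Documented usage: with past_key_values, feed only the new tokens.
past_length = past_key_values[0][0].shape[-2]     # 4 cached positions
new_tokens = input_ids[:, past_length:]           # the 4 uncached tokens
out = model(new_tokens, past_key_values=past_key_values)
print(out.logits.shape)  # (1, 4, vocab_size): logits only for the new positions
```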
Thanks!
`transformers`: v4.29
Python: v3.10.10