I need custom generation logic for my use case, and it appears that the only way to achieve this is by rewriting the generate method.
To start, I put together a simple, crude implementation that feeds the input ids through the model to get the next-token logits, samples them, and appends the new token to the input ids, all in a loop.
However, I have noticed that my implementation is significantly slower than the built-in model.generate() method. Here is a snippet of the two implementations:
import time

import torch
import torch.nn.functional as F
# `model` and `tokenizer` are loaded earlier (a causal LM and its tokenizer from transformers)

# manual
torch.manual_seed(0)
input_ids = tokenizer("test", return_tensors='pt', return_token_type_ids=False).to(0)["input_ids"]  # tokenise the prompt
input_tokens = len(input_ids[0])
test_tokens = 10
start_time = time.time()
for _ in range(test_tokens):
    model_output = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))  # feed the current generation through the model
    logits = model_output.logits[:, -1, :]  # get the next-token logits
    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)  # sample the logits to get the next token
    input_ids = torch.cat([input_ids, next_token], dim=-1)  # add the new token to the current generation
end_time = time.time()
print(f"manual: {end_time-start_time}")
# auto
torch.manual_seed(0)
inputs = tokenizer("test", return_tensors='pt', return_token_type_ids=False).to(0)  # full encoding (input_ids + attention_mask) for the prompt
start_time = time.time()
output = model.generate(
    **inputs,
    do_sample=True,
    max_length=input_tokens + test_tokens,
)
end_time = time.time()
print(f"auto: {end_time-start_time}")
Generating 10 tokens takes 6.08 seconds with my implementation versus 1.37 seconds with model.generate().
I would appreciate some pointers as to what differs in the model.generate() method to make it so much faster.
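The only candidate I have come up with so far is key/value caching: my loop re-runs the entire sequence through the model at every step, whereas I assume generate() reuses the attention cache and only processes the newest token. Below is a rough, untested sketch of how I imagine a cached version of my loop would look (reusing the setup from the snippet above), assuming the model accepts past_key_values and returns an updated cache when called with use_cache=True. I have not verified that this is actually what generate() does internally.

# sketch only: cached variant of the manual loop (assumes past_key_values support)
torch.manual_seed(0)
generated = tokenizer("test", return_tensors='pt', return_token_type_ids=False).to(0)["input_ids"]
past_key_values = None
for _ in range(test_tokens):
    # feed the full prompt on the first step, then only the newest token, reusing the cache
    model_input = generated if past_key_values is None else generated[:, -1:]
    model_output = model(input_ids=model_input, past_key_values=past_key_values, use_cache=True)
    past_key_values = model_output.past_key_values  # keep the updated cache for the next step
    logits = model_output.logits[:, -1, :]
    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    generated = torch.cat([generated, next_token], dim=-1)

Is caching the main difference here, or is there more to it?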
Thanks