I’m trying to collect inference-time statistics for several code-completion models on the HumanEval dataset. Since timing is central to this project, I don’t want to time the model while it generates irrelevant tokens. I therefore want to apply a StoppingCriteria to the code-completion models, namely models from the CodeGen, Code Llama, and WizardCoder families.
Currently, when the model has generated the full answer but has not reached the maximum number of new tokens (set to 200 here), it sometimes ends with an `<|endoftext|>` token. More often, though, it emits a double newline and keeps generating irrelevant text, which significantly distorts the timing.
Therefore, I want generation to stop the first time it produces a "\n\n" token or two consecutive "\n" tokens (["\n", "\n"]). How can I implement this?
To keep the test case simple, I set the batch size to 1 for each generation. I’d appreciate it if the solution also works when num_return_sequences is set to k, so that I can compute pass@k statistics.
The environment was pulled on 08-29-2023 from the latest Hugging Face transformers main branch (v4.33). GitHub repository: huggingface/transformers (🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX).
The Python version should be 3.8.0 or later. To test with various model checkpoints, use the checkpoint names given in the comments. I recommend testing with smaller models if you don’t have enough GPU VRAM.
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import time
import argparse
import torch
# Command-line options. Parsed at import time because the script entry point
# reads the module-level FLAGS.
parser = argparse.ArgumentParser(
    description="Time code-completion models on selected HumanEval prompts."
)
parser.add_argument("--checkpoint", type=str, default="Salesforce/codegen-2B-mono", help="Model path")
# Number of samples drawn per prompt (the k in pass@k); 1 keeps the old behavior.
parser.add_argument(
    "--num_return_sequences", type=int, default=1,
    help="Samples generated per prompt (k for pass@k)",
)
# Upper bound on generated tokens; the stopping criterion usually ends earlier.
parser.add_argument(
    "--max_new_tokens", type=int, default=200,
    help="Maximum number of new tokens per prompt",
)
FLAGS = parser.parse_args()
# WizardCoder Family
# WizardLM/WizardCoder-Python-34B-V1.0
# WizardLM/WizardCoder-Python-13B-V1.0
# WizardLM/WizardCoder-15B-V1.0
# WizardLM/WizardCoder-3B-V1.0
# WizardLM/WizardCoder-1B-V1.0
# Code LLAMA 2 Family
# codellama/CodeLlama-7b-hf
# codellama/CodeLlama-13b-hf
# codellama/CodeLlama-34b-hf
# Salesforce Codegen Family
# Salesforce/codegen-350M-mono
# Salesforce/codegen-2B-mono
# Salesforce/codegen-6B-mono
# Salesforce/codegen-16B-mono
# A sample's generation is considered finished once its completion contains
# any of these strings.
# NOTE(review): "\n\n" also fires on a blank line *inside* a function body.
# The Codex paper's HumanEval setup stopped on
# ["\nclass", "\ndef", "\n#", "\nif", "\nprint"] instead; pick what suits the study.
stop_words = ["\n\n"]

# HumanEval prompts, reproduced with the dataset's four-space docstring
# indentation (a one-space indent would feed the models a malformed prompt).
# HumanEval Q0
prompt_0 = (
    "from typing import List\n"
    "\n"
    "\n"
    "def has_close_elements(numbers: List[float], threshold: float) -> bool:\n"
    '    """ Check if in given list of numbers, are any two numbers closer to each other than\n'
    "    given threshold.\n"
    "    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n"
    "    False\n"
    "    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n"
    "    True\n"
    '    """\n'
)
# HumanEval Q31
prompt_31 = (
    "\n"
    "\n"
    "def is_prime(n):\n"
    '    """Return true if a given number is prime, and false otherwise.\n'
    "    >>> is_prime(6)\n"
    "    False\n"
    "    >>> is_prime(101)\n"
    "    True\n"
    "    >>> is_prime(11)\n"
    "    True\n"
    "    >>> is_prime(13441)\n"
    "    True\n"
    "    >>> is_prime(61)\n"
    "    True\n"
    "    >>> is_prime(4)\n"
    "    False\n"
    "    >>> is_prime(1)\n"
    "    False\n"
    '    """\n'
)
# HumanEval Q35
prompt_35 = (
    "\n"
    "\n"
    "def max_element(l: list):\n"
    '    """Return maximum element in the list.\n'
    "    >>> max_element([1, 2, 3])\n"
    "    3\n"
    "    >>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])\n"
    "    123\n"
    '    """\n'
)
# HumanEval Q161
prompt_161 = (
    "\n"
    "def solve(s):\n"
    '    """You are given a string s.\n'
    "    if s[i] is a letter, reverse its case from lower to upper or vise versa, \n"
    "    otherwise keep it as it is.\n"
    "    If the string contains no letters, reverse the string.\n"
    "    The function should return the resulted string.\n"
    "    Examples\n"
    '    solve("1234") = "4321"\n'
    '    solve("ab") = "AB"\n'
    '    solve("#a@C") = "#A@c"\n'
    '    """\n'
)
class StopOnStrings(StoppingCriteria):
    """Stop generation once every sequence in the batch is finished.

    A row counts as finished when the text generated after the first
    ``start_length`` (prompt) tokens contains any of ``stop_strings``, or when
    the row has emitted ``eos_token_id``. Only the generated part is inspected,
    so stop strings that already occur in the prompt (e.g. the blank lines
    before ``def``) cannot trigger an immediate stop.

    ``__call__`` returns a single bool for the whole batch, which is the
    contract ``generate`` uses in transformers 4.33. With
    ``num_return_sequences=k``, generation therefore ends only when all k
    samples are finished. Rows that finished earlier keep receiving tokens,
    which the caller trims afterwards. The instance remembers which rows are
    finished, so build a fresh one for every ``generate`` call.
    """

    def __init__(self, tokenizer, stop_strings, start_length, eos_token_id=None):
        self.tokenizer = tokenizer
        self.stop_strings = tuple(stop_strings)
        self.start_length = start_length  # number of prompt tokens to skip
        self.eos_token_id = eos_token_id
        self._finished = None  # per-row flags, sized on the first call

    def _row_finished(self, new_ids):
        """Return True if one row's generated token ids hit EOS or a stop string."""
        if self.eos_token_id is not None and self.eos_token_id in new_ids:
            return True
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        return any(stop in text for stop in self.stop_strings)

    def __call__(self, input_ids, scores, **kwargs):
        # One device-to-host copy per step rather than one per row.
        rows = input_ids[:, self.start_length:].tolist()
        if self._finished is None or len(self._finished) != len(rows):
            self._finished = [False] * len(rows)
        for row_idx, new_ids in enumerate(rows):
            if not self._finished[row_idx]:
                self._finished[row_idx] = self._row_finished(new_ids)
        return all(self._finished)


def _truncate_at_stop(text, stop_strings):
    """Return ``text`` cut just before the earliest occurrence of any stop string."""
    hits = [pos for pos in (text.find(stop) for stop in stop_strings) if pos != -1]
    return text[:min(hits)] if hits else text


def _sync_cuda():
    """Wait for queued CUDA kernels so wall-clock timings cover the real work."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def main(args):
    """Load ``args.checkpoint`` and time sampling on the selected HumanEval prompts.

    Each ``generate`` call stops as soon as every returned sample has produced
    one of the module-level ``stop_words`` (or EOS). This keeps the time spent
    on irrelevant trailing text out of the measurement. Each sample's output is
    printed trimmed at the first stop word.

    Args:
        args: parsed CLI namespace. ``checkpoint`` is required.
            ``num_return_sequences`` (default 1) and ``max_new_tokens``
            (default 200) are used when the parser defines them.
    """
    checkpoint = args.checkpoint
    num_return_sequences = getattr(args, "num_return_sequences", 1)
    max_new_tokens = getattr(args, "max_new_tokens", 200)

    # Initialize model and tokenizer; device placement only applies to the model.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    start_load_model = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
    print(f"Time to load model {checkpoint} is {time.perf_counter() - start_load_model}")

    # NOTE(review): the first generate call also pays one-off CUDA warm-up costs;
    # consider a throwaway warm-up generation before timing if that matters.
    # The per-step decode inside StopOnStrings is part of the measured time (small
    # for a few hundred tokens, but not zero).
    for prompt in [prompt_0, prompt_31, prompt_35, prompt_161]:
        encoded = tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the inputs must sit on the model's (first) device.
        input_ids = encoded.input_ids.to(model.device)
        attention_mask = encoded.attention_mask.to(model.device)
        prompt_len = input_ids.shape[1]
        stopping_criteria = StoppingCriteriaList(
            [StopOnStrings(tokenizer, stop_words, prompt_len, tokenizer.eos_token_id)]
        )

        _sync_cuda()
        start_generating = time.perf_counter()
        generated_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            num_beams=1,
            num_return_sequences=num_return_sequences,
            stopping_criteria=stopping_criteria,
        )
        _sync_cuda()
        elapsed = time.perf_counter() - start_generating  # measured exactly once

        # Decoding steps actually executed; every returned row has this length.
        generated_len = generated_ids.shape[1] - prompt_len
        print(f"Time to generate is {elapsed}")
        print(f"per token time is {elapsed / max(generated_len, 1)}")

        # clean_up_tokenization_spaces=False keeps code spacing intact, so the
        # decoded prompt stays a prefix of each decoded full sequence.
        decode_kwargs = {"skip_special_tokens": True, "clean_up_tokenization_spaces": False}
        prompt_text = tokenizer.decode(input_ids[0], **decode_kwargs)
        for sample_idx, sequence in enumerate(generated_ids):
            new_ids = sequence[prompt_len:]
            decoded_list = [tokenizer.decode(int(token_id)) for token_id in new_ids]
            full_text = tokenizer.decode(sequence, **decode_kwargs)
            if full_text.startswith(prompt_text):
                completion = full_text[len(prompt_text):]
            else:  # tokenizer merged across the boundary; decode new tokens alone
                completion = tokenizer.decode(new_ids, **decode_kwargs)
            # Rows that finished early kept sampling until the whole batch stopped.
            completion = _truncate_at_stop(completion, stop_words)

            # Print outputs
            print(f"\n[sample {sample_idx}] decoded_list is {decoded_list}")
            print(f"[sample {sample_idx}] generated_text is:\n{prompt_text + completion}")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main(FLAGS)