Implementing StoppingCriteria for Code Generating Transformers

I’m trying to collect inference-time statistics for different code-completion models on the HumanEval dataset. Since timing is a crucial part of this project, I don’t want to time the model while it generates irrelevant tokens. I’d therefore like to add a StoppingCriteria to the code-completion models, namely models from the Codegen, Code LLAMA, and WizardCoder families.

Currently, when the model has generated the full answer but hasn’t reached the maximum number of new tokens (here set to 200), it may end with an <|endoftext|> token, but more often it emits a double newline and keeps generating irrelevant text. This significantly skews the timing.

Therefore, I’d like generation to stop the first time it produces a "\n\n" token or two consecutive "\n" tokens (["\n", "\n"]). How can I implement this?
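One way to express the two cases at the token level is to compare the tail of the newly generated ids against a list of stop sequences: a one-element sequence for a single "\n\n" token and a two-element sequence for ["\n", "\n"]. The sketch below is plain Python; NL and NL_NL are hypothetical placeholder ids, because the real ids differ between the Codegen, Code LLAMA and WizardCoder tokenizers and have to be looked up with each model's own tokenizer. An exact-id check like this may also miss tokens that merge newlines with other whitespace.

```python
from typing import List, Sequence

# Hypothetical placeholder ids: look up the real ones with the model's tokenizer,
# since they differ between model families.
NL = 1     # placeholder id for "\n"
NL_NL = 2  # placeholder id for "\n\n"

# Each stop sequence is a list of token ids that must appear consecutively.
STOP_SEQUENCES: List[List[int]] = [[NL_NL], [NL, NL]]


def ends_with_stop_sequence(generated_ids: Sequence[int],
                            stop_sequences: List[List[int]]) -> bool:
    """Return True if the newly generated ids end with any stop sequence."""
    for stop in stop_sequences:
        # Compare the last len(stop) generated ids with the stop sequence
        if len(generated_ids) >= len(stop) and list(generated_ids[-len(stop):]) == stop:
            return True
    return False
```

Because the check runs on every decoding step and only looks at the end of the sequence, it fires on the first occurrence. Only the generated ids should be passed in, so newlines inside the prompt never trigger it.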

To keep the test case simple, I set the batch size to 1 for each generation. I’d appreciate it if the solution also works when num_return_sequences is set to k, so that I can compute pass@k.
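With num_return_sequences=k, a stopping criterion in transformers v4.33 ends generation for all k sequences at once, so a common approach is to keep generating until every sequence has produced a stop word and then cut each completion at its first stop word before scoring pass@k. A minimal post-processing sketch, assuming the prompt has already been stripped from each decoded completion; truncate_at_stop is a hypothetical helper name and the completions are made-up strings.

```python
from typing import List


def truncate_at_stop(completion: str, stop_words: List[str]) -> str:
    """Cut a generated completion at the earliest occurrence of any stop word."""
    cut = len(completion)
    for stop in stop_words:
        idx = completion.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return completion[:cut]


# Example: k = 3 sampled completions for one prompt (made-up strings)
completions = [
    "    return True\n\nprint('irrelevant')\n",
    "    x = 1\n    return x > 0\n\n\ndef other():\n",
    "    return False",
]
cleaned = [truncate_at_stop(c, ["\n\n"]) for c in completions]
```

For per-token timing, note that sequences which hit the stop word early still take part in every decoding step until the last sequence finishes.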

The environment was installed on 08-29-2023 from the latest main branch of huggingface transformers (v4.33). The GitHub repo is: GitHub - huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

The Python version should be 3.8.0 or later. To test with different model checkpoints, use the checkpoint names given in the comments. I recommend testing with smaller models if you don’t have enough GPU VRAM.

from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import time
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="Salesforce/codegen-2B-mono", help="Model path")
FLAGS = parser.parse_args()

# WizardCoder Family
# WizardLM/WizardCoder-Python-34B-V1.0
# WizardLM/WizardCoder-Python-13B-V1.0
# WizardLM/WizardCoder-15B-V1.0
# WizardLM/WizardCoder-3B-V1.0
# WizardLM/WizardCoder-1B-V1.0

# Code LLAMA 2 Family
# codellama/CodeLlama-7b-hf
# codellama/CodeLlama-13b-hf
# codellama/CodeLlama-34b-hf

# Salesforce Codegen Family
# Salesforce/codegen-350M-mono
# Salesforce/codegen-2B-mono
# Salesforce/codegen-6B-mono
# Salesforce/codegen-16B-mono

stop_words = ["\n\n"]
# HumanEval Q0
prompt_0 = "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n"
# HumanEval Q31
prompt_31 = "\n\ndef is_prime(n):\n    \"\"\"Return true if a given number is prime, and false otherwise.\n    >>> is_prime(6)\n    False\n    >>> is_prime(101)\n    True\n    >>> is_prime(11)\n    True\n    >>> is_prime(13441)\n    True\n    >>> is_prime(61)\n    True\n    >>> is_prime(4)\n    False\n    >>> is_prime(1)\n    False\n    \"\"\"\n"
# HumanEval Q35
prompt_35 = "\n\ndef max_element(l: list):\n    \"\"\"Return maximum element in the list.\n    >>> max_element([1, 2, 3])\n    3\n    >>> max_element([5, 3, -5, 2, -3, 3, 9, 0, 123, 1, -10])\n    123\n    \"\"\"\n"
# HumanEval Q161
prompt_161 = "\ndef solve(s):\n    \"\"\"You are given a string s.\n    if s[i] is a letter, reverse its case from lower to upper or vise versa, \n    otherwise keep it as it is.\n    If the string contains no letters, reverse the string.\n    The function should return the resulted string.\n    Examples\n    solve(\"1234\") = \"4321\"\n    solve(\"ab\") = \"AB\"\n    solve(\"#a@C\") = \"#A@c\"\n    \"\"\"\n"


def main(args):
    # Initialize model and tokenizer
    checkpoint = args.checkpoint
    # device_map only applies to the model, not the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    start_load_model = time.time()
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
    print(f"Time to load model {checkpoint} is {time.time() - start_load_model}")
    
    # Generate the selected prompts
    for prompt in [prompt_0, prompt_31, prompt_35, prompt_161]:
        # Put the prompt on the same device as the model
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        start_generating = time.time()
        generated_ids = model.generate(
            input_ids,
            use_cache = True,
            pad_token_id = tokenizer.eos_token_id,
            max_new_tokens = 200,
            do_sample = True,
            temperature = 0.8,
            num_beams=1,
            # stopping_criteria=stopping_criteria,
        )
        # Take the time once, right after generation, so decoding and printing are not counted
        elapsed = time.time() - start_generating
        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        # Decode only the new tokens, i.e. everything after the prompt
        prompt_len = input_ids.shape[1]
        decoded_list = [tokenizer.decode(int(ids)) for ids in generated_ids[0, prompt_len:]]
        generated_len = len(decoded_list)

        # Print outputs
        print(f"Time to generate is {elapsed}")
        print(f"per token time is {elapsed / generated_len}")
        print(f"decoded_list is {decoded_list}")
        print(f"\ngenerated_text is:\n{generated_text[0]}")

if __name__== "__main__":
    main(FLAGS)
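To fill in the commented-out stopping_criteria argument of model.generate in the script above, one option is a StoppingCriteria subclass that decodes only the tokens produced after the prompt and checks the text for the stop words. Checking decoded text rather than token ids catches both a single "\n\n" token and two consecutive "\n" tokens, whichever tokenizer is used. This is a minimal sketch assuming transformers v4.33, where __call__ returns one bool that stops the whole batch; it therefore returns True only once every sequence (including each of the num_return_sequences samples) contains a stop word. The class name StopOnWords is hypothetical.

```python
from typing import List

import torch
from transformers import StoppingCriteria


class StopOnWords(StoppingCriteria):
    """Stop generation once the new text of every sequence contains a stop word."""

    def __init__(self, start_length: int, stop_words: List[str], tokenizer):
        # start_length: number of prompt tokens, so "\n\n" inside the prompt is ignored
        self.start_length = start_length
        self.stop_words = stop_words
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the generated part of every row in the batch
        new_texts = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        # True only when all rows have produced at least one stop word
        return all(any(stop in text for stop in self.stop_words) for text in new_texts)
```

Inside the loop in main(), it could be built right after tokenizing the prompt, e.g. StoppingCriteriaList([StopOnWords(input_ids.shape[1], stop_words, tokenizer)]), and passed to model.generate together with num_return_sequences=k. Decoding the whole generated part on every step adds some CPU work per token that ends up in the timing; checking only the last few tokens would keep that overhead smaller.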

Hi @BoyuanJackchen

Check this answer: Implimentation of Stopping Criteria List - #9 by berkecr

It might help you

Regards

Nuno