Llama 3.1 torch.compile & static cache

I want to speed up the inference time of Llama 3.1 8B Instruct on my local computer. I have one NVIDIA GeForce RTX 4080.
I am running the code on Ubuntu 22.04.

transformers             4.46.3
tokenizers               0.20.3
torch                    2.5.1

I got this error in the generate function: Unsupported: torch.* op returned non-Tensor device call_function

Below is the code to reproduce the error:

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import torch
nf4_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16)

checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
llm_model=AutoModelForCausalLM.from_pretrained(checkpoint,
                                    quantization_config=nf4_config,
                                    low_cpu_mem_usage=True,
                                    attn_implementation="flash_attention_2",
                                    device_map="auto"
                                    )

tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
llm_model.forward = torch.compile(
    llm_model.forward, mode="reduce-overhead", fullgraph=True)

messages = [{'role': 'system', 'content': 'You are an annotator for extracting verbs from english sentences'},
            {'role': 'user', 'content': 'English sentences:\n```I like pizza. I would like an ice cream```. The output should be a valid JSON format'},
            {'role': 'assistant', 'content': '{"verbs": ["like", "would like"]}'},
            {'role': 'user', 'content': 'English sentences:\n```I enjoy watching football games```. The output should be a valid JSON format'}]

tokenizer.chat_template="""{% set begin_of_text = '<|begin_of_text|>' %}{% set start_header_id = '<|start_header_id|>' %}{% set end_header_id = '<|end_header_id|>' %}{% set eot_id = '<|eot_id|>' %}{{ begin_of_text }}{% if messages[0]['role'] == 'system' %}{{ start_header_id }}system{{ end_header_id }}{{ messages[0]['content'] }}{{ eot_id }}{% set loop_messages = messages[1:] %}{% else %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{{ start_header_id }}{{ message['role'] }}{{ end_header_id }}{{ message['content'] }}{{ eot_id }}{% endfor %}{{ start_header_id }}assistant{{ end_header_id }}"""
prompts = [tokenizer.apply_chat_template(messages, tokenize=False)]*20
import time
tokenizer_config={"padding":True}
generation_config = {"max_new_tokens": 512, "use_cache": True, "do_sample": True,
                    "temperature": 0.001, "top_p": 0.9,
                    "cache_implementation": "static"}

outputs = []
start = time.perf_counter()
with torch.no_grad():
    tokens_nums = 0
    for batch in prompts:
        inputs_model = tokenizer(
            batch, return_tensors="pt", **tokenizer_config)
        inputs_model = inputs_model.to(llm_model.device)

        # padded prompt length, used below to strip the prompt from the output
        model_input_length = inputs_model["input_ids"].shape[1]
        output_encode = llm_model.generate(**inputs_model, **generation_config,
                                                pad_token_id=tokenizer.eos_token_id)

        output_encode = output_encode[:, model_input_length:]
        tokens_nums += output_encode.shape[1]
        output = tokenizer.batch_decode(
            output_encode, skip_special_tokens=True)
        outputs.extend(output)
end = time.perf_counter()

print("Tokens speed:\s", tokens_nums/(end-start))

I think it’s because of torch.compile.
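
One way to check that (a sketch, assuming you can rerun the script; torch._logging and the TORCH_LOGS environment variable are the usual switches here) is to relax fullgraph and turn on Dynamo's graph-break logging, so the untraceable op gets printed instead of raising Unsupported:

# Sketch: use this in place of the fullgraph=True compile call in the script.
# With fullgraph=False, Dynamo falls back to eager at the op it cannot trace,
# and the artifact logger reports where each graph break happens.
import torch
import torch._logging

torch._logging.set_logs(graph_breaks=True, recompiles=True)

llm_model.forward = torch.compile(
    llm_model.forward, mode="reduce-overhead", fullgraph=False)

Setting TORCH_LOGS="graph_breaks" when launching the script should give the same information without code changes.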

It looks like that issue is not fixed yet.
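
Until it is, a fallback sketch (my assumption, not a verified fix): drop the manual torch.compile on the 4-bit model and keep only the static cache, which generate() handles on its own through cache_implementation. You lose the compiled-graph speedup but keep the pre-allocated KV cache:

# Sketch: remove the llm_model.forward = torch.compile(...) line from the
# script above; the static cache by itself does not require Dynamo to trace
# the bitsandbytes layers.
llm_model.generation_config.cache_implementation = "static"

output_encode = llm_model.generate(**inputs_model,
                                   max_new_tokens=512, do_sample=True,
                                   temperature=0.001, top_p=0.9,
                                   pad_token_id=tokenizer.eos_token_id)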
