AOTInductor with Llama-3.2-3B-Instruct

Hello,

I am trying to compile a Llama model ahead of time with AOTInductor, then load the packaged artifact back for inference. I am using the following code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.nn.attention import SDPBackend

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct").eval().cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
base_prompt = "How many hours are in a day?"
base_inputs = tokenizer(base_prompt, return_tensors="pt").to('cuda')

# make dim 1 (the sequence length) of both inputs dynamic, up to 1024 tokens
seq_len = torch.export.Dim("seq_len", min=1, max=1024)

# restrict SDPA to the math backend so export traces a plain attention implementation
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    ep = torch.export.export(
        model,
        (
            base_inputs.input_ids,
            base_inputs.attention_mask
        ),
        dynamic_shapes=(
            {1: seq_len},
            {1: seq_len},
        ),
    )

    output_path = "exportedModels/llama32-3b.pt2"
    torch._inductor.aoti_compile_and_package(
        ep,
        (
            base_inputs.input_ids,
            base_inputs.attention_mask,
        ),
        package_path=output_path,
    )

compiled_func = torch._inductor.aoti_load_package(output_path)
# single forward pass over the prompt
outputs_comp = compiled_func(tuple(base_inputs.values()))
# greedy argmax over the vocabulary at every position of the prompt
response_comp = tokenizer.decode(torch.argmax(outputs_comp.logits, dim=-1)[0])
print(response_comp)

I get:

The to people of in a day?
 There

so the compiled model is not working properly, while the uncompiled model produces the expected answer:

outputs_model = model.generate(**base_inputs, max_length=128)
response_model = tokenizer.batch_decode(outputs_model)[0]
print(response_model)
<|begin_of_text|>How many hours are in a day? There are 24 hours in a day.
How many hours are in a year? There are 8760 hours in a year (24 hours/day x 365 days/year).
How many hours are in a century? There are 8760 hours/year x 100 years = 876,000 hours in a century.
How many hours are in a millennium? There are 876,000 hours/century x 1000 centuries = 8,760,000 hours in a millennium.
There are 8,760,000 hours in a millennium.
There are 24 hours in
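
For comparison, I would expect to be able to drive the compiled forward pass autoregressively with a simple greedy loop along these lines (my own untested sketch: no KV cache, so the full prefix is recomputed at every step, using the same tuple calling convention as above):

input_ids = base_inputs.input_ids
attention_mask = base_inputs.attention_mask
for _ in range(32):
    # same calling convention as the single-pass call above
    out = compiled_func((input_ids, attention_mask))
    # greedily pick the next token from the last position only
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
print(tokenizer.decode(input_ids[0]))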

I have also experimented with exporting model.model instead, plugging the compiled module back in as model.model, and then calling generate(), but no luck there either.
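
That experiment was roughly along these lines (hypothetical sketch: CompiledBackbone is just my wrapper name, and compiled_backbone stands for the loaded package of the exported model.model):

import torch.nn as nn

class CompiledBackbone(nn.Module):
    def __init__(self, compiled_fn, config):
        super().__init__()
        self.compiled_fn = compiled_fn
        self.config = config

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # generate() forwards extra arguments (position_ids, past_key_values,
        # cache_position, use_cache, ...) that the exported graph never saw;
        # they are silently dropped here, which may itself be the problem.
        return self.compiled_fn((input_ids, attention_mask))

model.model = CompiledBackbone(compiled_backbone, model.config)
outputs_model = model.generate(**base_inputs, max_length=128)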

Can someone help me with this? Are there additional steps that need to be taken when exporting and running this model with AOTInductor?
