How to torch.export TinyLlama's prefill model

I’m working with an LLM model and would like to export it using torch.export.export to obtain the PyTorch IR.

As I understand, exporting an LLM typically involves two separate graphs: prefill and decode. I’ve already successfully exported the decode graph.

import torch, transformers
from torch.export import Dim
from transformers.cache_utils import StaticCacheConfig
from transformers.integrations.executorch import convert_and_export_with_cache
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    device_map="cpu",
)

# batch=1; export fails if the batch dimension is made dynamic
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = StaticCacheConfig(
    batch_size=1, max_cache_len=1024, device="cpu"
)

# only seq is dynamic
seq = Dim("seq", min=1, max=1024)
dynamic_shapes = {
    "input_ids":      {1: seq},
    "cache_position": {0: seq},
}

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids      = torch.randint(0, 32000, (1, 4)),  # 1×4
    example_cache_position = torch.arange(4),
    dynamic_shapes         = dynamic_shapes,
)

print("export scuess!", exported_program)

However, I’m having trouble exporting the prefill graph: the model internally uses vmap (in the causal-mask construction), which causes torch.export to fail. Or am I using it the wrong way?

What would be the recommended way to export the prefill graph in this case?

import torch
from torch.export import export, Dim
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    trust_remote_code=True
).eval()

model.generation_config.use_cache = True

class PrefillWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor):
        # input_ids: [1, T]
        B, T = input_ids.shape
        position_ids = torch.arange(T, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=None,
            past_key_values=None,
            use_cache=True,
        )
        return outputs.logits  


wrapper = PrefillWrapper(model)


example_input = torch.tensor([[1, 2, 3]], dtype=torch.long)  # shape = [1, T]

dynamic_shapes = {
    "input_ids": {1: Dim("seq", min=1, max=2048)} 
}


ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes)


with open("prefill_dynamic_seq.ir", "w") as f:
    f.write(str(ep.module()))

It’d be easy if this worked…

#ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes)
ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes, strict=False)
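If it helps, a small fallback sketch (reusing wrapper, example_input, and dynamic_shapes from the post above) is to attempt strict export first and only drop to non-strict tracing when Dynamo rejects the vmap-based mask:

# Attempt strict export, then fall back to non-strict tracing on failure.
try:
    ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes)
except Exception as err:
    print("strict export failed:", err)
    ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes, strict=False)
print(ep.graph)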

I found a way: bypass vmap in the causal-mask construction.


import torch
from torch.export import export, Dim
from transformers import AutoModelForCausalLM
import transformers.masking_utils as mu
import transformers.models.llama.modeling_llama as modeling_llama

# ----------------------------------------------------------------------
# Drop-in replacement for transformers' vmap-based create_causal_mask:
# build the causal mask with plain tensor ops so torch.export can trace it.
# ----------------------------------------------------------------------
def causal_mask_no_vmap(*,
                        input_embeds=None,
                        attention_mask=None,
                        position_ids=None,
                        dtype=None,
                        device=None,
                        **_):
    # Infer batch size and query length from whichever argument is available.
    if position_ids is not None:
        bsz, q_len = position_ids.shape
        device = device or position_ids.device
    elif input_embeds is not None:
        bsz, q_len = input_embeds.shape[:2]
        device = device or input_embeds.device
    elif attention_mask is not None:
        bsz, q_len = attention_mask.shape
        device = device or attention_mask.device
    else:
        raise RuntimeError("Cannot infer sequence length for causal mask")

    k_len = q_len                          # prefill: K == Q
    dtype = dtype or torch.float32
    # Additive causal mask as consumed by eager attention: 0 where attending
    # is allowed, a large negative value above the diagonal (future positions).
    min_val = torch.finfo(dtype).min
    mask = torch.full((q_len, k_len), min_val, dtype=dtype, device=device).triu(1)
    return mask.view(1, 1, q_len, k_len).expand(bsz, -1, -1, -1)  # [B, 1, Q, K]

# Patch both the shared helper and the name already imported into
# modeling_llama (it is bound there at import time).
mu.create_causal_mask = causal_mask_no_vmap
modeling_llama.create_causal_mask = causal_mask_no_vmap


model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    trust_remote_code=True,
).eval()
model.generation_config.use_cache = True
# Force eager attention so the additive mask above is consumed directly.
model.config.attn_implementation = model.config._attn_implementation = "eager"


class PrefillWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m
    def forward(self, input_ids):      # [1, T]
        B, T = input_ids.shape         # B = 1
        device = input_ids.device
        pos  = torch.arange(T, device=device).unsqueeze(0)
        mask = torch.ones((1, T), dtype=torch.bool, device=device)
        return self.m(
            input_ids=input_ids,
            position_ids=pos,
            attention_mask=mask,
            past_key_values=None,
            use_cache=True,
        ).logits                        # [1, T, vocab]

wrapper = PrefillWrapper(model)


example = torch.tensor([[1, 2, 3]], dtype=torch.long)


seq = Dim("seq", min=1, max=1024)
dyn_shapes = {"input_ids": {1: seq}}

with torch.no_grad():
    ep = export(wrapper, (example,), dynamic_shapes=dyn_shapes)

print(ep.graph)
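
To verify that the patched mask did not change the prefill output, the exported graph can be compared against the eager wrapper on a prompt of a different length (a sketch; 16 is just an arbitrary length within the declared seq range):

# Compare the exported prefill graph against the patched eager model.
test_ids = torch.randint(0, 32000, (1, 16), dtype=torch.long)
with torch.no_grad():
    ref_logits = wrapper(test_ids)
    exp_logits = ep.module()(test_ids)
print(exp_logits.shape)                                    # [1, 16, vocab_size]
print(torch.allclose(ref_logits, exp_logits, atol=1e-4))   # expect True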
