How to torch.export TinyLlama's prefill model

I found a way: the vmap-based mask construction inside create_causal_mask is what breaks the export, so I replace it with a plain lower-triangular mask built without vmap. Full script below.


import torch
from torch.export import export, Dim
from transformers import AutoModelForCausalLM
import transformers.masking_utils as mu
import transformers.models.llama.modeling_llama as modeling_llama

# ----------------------------------------------------------------------
# Replace the vmap-based causal-mask builder with a plain tril mask
# ----------------------------------------------------------------------
def causal_mask_no_vmap(*,
                        input_embeds=None,
                        attention_mask=None,
                        position_ids=None,
                        dtype=None,
                        device=None,
                        **_):
    """Drop-in replacement for create_causal_mask that avoids vmap.

    Returns the additive mask that eager attention expects: 0.0 where a token
    may attend, dtype-min above the diagonal. Padding is ignored (prefill only).
    """
    # Infer batch size and sequence length from whichever argument is present.
    if position_ids is not None:
        bsz, q_len = position_ids.shape
        device = device or position_ids.device
    elif input_embeds is not None:
        bsz, q_len = input_embeds.shape[:2]
        device = device or input_embeds.device
    elif attention_mask is not None:
        bsz, q_len = attention_mask.shape
        device = device or attention_mask.device
    else:
        raise RuntimeError("Cannot infer sequence length for causal mask")

    k_len = q_len                          # prefill: K == Q (no KV cache yet)
    dtype = dtype or torch.float32
    min_val = torch.finfo(dtype).min
    # Lower triangle (including the diagonal) stays 0.0; everything above is
    # dtype-min, so adding the mask to the attention logits hides future tokens.
    mask = torch.full((q_len, k_len), min_val, dtype=dtype, device=device).triu(diagonal=1)
    return mask.view(1, 1, q_len, k_len).expand(bsz, -1, -1, -1)  # [B, 1, Q, K]

# Patch both the shared helper and the reference imported into modeling_llama,
# so every code path builds the mask without vmap.
mu.create_causal_mask = causal_mask_no_vmap
modeling_llama.create_causal_mask = causal_mask_no_vmap
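
# Optional sanity check of the replacement: for a 4-token prompt the mask is
# [1, 1, 4, 4] with zeros on/below the diagonal and dtype-min above it.
_demo = causal_mask_no_vmap(position_ids=torch.arange(4).unsqueeze(0))
assert _demo.shape == (1, 1, 4, 4)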


# Load TinyLlama and force eager attention, so the additive mask built above
# is applied directly to the attention logits.
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    trust_remote_code=True,
).eval()
model.generation_config.use_cache = True
model.config.attn_implementation = model.config._attn_implementation = "eager"


# Prefill-only wrapper: batch of 1, positions 0..T-1, all-ones attention mask,
# no past KV cache; returns the logits for every prompt token.
class PrefillWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m
    def forward(self, input_ids):      # [1, T]
        B, T = input_ids.shape         # B = 1
        device = input_ids.device
        pos  = torch.arange(T, device=device).unsqueeze(0)
        mask = torch.ones((1, T), dtype=torch.bool, device=device)
        return self.m(
            input_ids=input_ids,
            position_ids=pos,
            attention_mask=mask,
            past_key_values=None,
            use_cache=True,
        ).logits                        # [1, T, vocab]

wrapper = PrefillWrapper(model)


# Example prompt (arbitrary token ids) used for tracing; the sequence
# dimension is marked dynamic below.
example = torch.tensor([[1, 2, 3]], dtype=torch.long)


# Let the prompt length (dim 1 of input_ids) vary from 1 to 1024 tokens.
seq = Dim("seq", min=1, max=1024)
dyn_shapes = {"input_ids": {1: seq}}

with torch.no_grad():
    ep = export(wrapper, (example,), dynamic_shapes=dyn_shapes)

print(ep.graph)
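
To double-check the dynamic sequence dimension, you can run the exported program at another prompt length and compare it with the eager wrapper; saving the program is optional (a minimal sketch, assuming the export above succeeded and an arbitrary file name):

# Run the exported graph at a length different from the tracing example and
# compare against the eager wrapper.
check = torch.randint(0, model.config.vocab_size, (1, 17), dtype=torch.long)
with torch.no_grad():
    ref = wrapper(check)
    out = ep.module()(check)
print(out.shape)                            # [1, 17, vocab_size]
print(torch.allclose(ref, out, atol=1e-5))  # should print True

# Optional: serialize the exported program for later use.
torch.export.save(ep, "tinyllama_prefill.pt2")
# ep2 = torch.export.load("tinyllama_prefill.pt2")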