I found a way to make the export work: avoid vmap in the causal-mask construction. The idea is to monkey-patch create_causal_mask with a plain tril-style implementation, wrap the model so the exported graph only sees input_ids, and export the prefill pass with a dynamic sequence dimension.
import torch
from torch.export import export, Dim
from transformers import AutoModelForCausalLM
import transformers.masking_utils as mu
import transformers.models.llama.modeling_llama as modeling_llama

# ----------------------------------------------------------------------
# Drop-in replacement for create_causal_mask that builds the causal mask
# without vmap, so torch.export does not trip over it
# ----------------------------------------------------------------------
def causal_mask_no_vmap(*,
                        input_embeds=None,
                        attention_mask=None,
                        position_ids=None,
                        dtype=None,
                        device=None,
                        **_):
    # Infer (batch size, query length) from whichever tensor the caller passed.
    if position_ids is not None:
        bsz, q_len = position_ids.shape
        device = device or position_ids.device
    elif input_embeds is not None:
        bsz, q_len = input_embeds.shape[:2]
        device = device or input_embeds.device
    elif attention_mask is not None:
        bsz, q_len = attention_mask.shape
        device = device or attention_mask.device
    else:
        raise RuntimeError("Cannot infer sequence length for causal mask")

    k_len = q_len                      # prefill: K == Q
    dtype = dtype or torch.float32
    # Additive mask, as the eager attention path expects: 0 where attention is
    # allowed, a large negative value on future (masked) positions.
    mask = torch.full((q_len, k_len), torch.finfo(dtype).min,
                      dtype=dtype, device=device).triu(diagonal=1)
    return mask.view(1, 1, q_len, k_len).expand(bsz, -1, -1, -1)  # [B, 1, Q, K]
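
# Patch both the shared helper and the name already imported into modeling_llama.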
mu.create_causal_mask = causal_mask_no_vmap
modeling_llama.create_causal_mask = causal_mask_no_vmap
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    trust_remote_code=True,
).eval()
model.generation_config.use_cache = True
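# Force eager attention for the export (no SDPA / flash-attention kernels).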
model.config.attn_implementation = model.config._attn_implementation = "eager"
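
# Prefill-only wrapper: builds position_ids and a full attention mask inside
# forward, so the exported graph takes nothing but input_ids.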
class PrefillWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, input_ids):                     # [1, T]
        B, T = input_ids.shape                        # B = 1
        device = input_ids.device
        pos = torch.arange(T, device=device).unsqueeze(0)
        mask = torch.ones((1, T), dtype=torch.bool, device=device)
        return self.m(
            input_ids=input_ids,
            position_ids=pos,
            attention_mask=mask,
            past_key_values=None,
            use_cache=True,
        ).logits                                      # [1, T, vocab]
wrapper = PrefillWrapper(model)
example = torch.tensor([[1, 2, 3]], dtype=torch.long)
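# Mark dim 1 of input_ids (the sequence length) as dynamic, from 1 to 1024 tokens.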
seq = Dim("seq", min=1, max=1024)
dyn_shapes = {"input_ids": {1: seq}}
with torch.no_grad():
    ep = export(wrapper, (example,), dynamic_shapes=dyn_shapes)
print(ep.graph)
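
For completeness, here is a quick sanity check I would run right after the export. This is my own sketch, not part of the original repro (the sequence length 17 and the tolerance are arbitrary): run the exported module on a different sequence length and compare it against the eager wrapper, to confirm the dynamic seq dimension actually works.

# Sanity check (assumes the export above succeeded).
sanity = torch.randint(0, model.config.vocab_size, (1, 17), dtype=torch.long)
with torch.no_grad():
    ref = wrapper(sanity)            # eager reference
    out = ep.module()(sanity)        # run the exported graph
print(out.shape)                     # expect [1, 17, vocab_size]
print(torch.allclose(ref, out, atol=1e-4))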