I’m working with an LLM and would like to export it with torch.export.export to obtain the PyTorch IR.
As I understand it, exporting an LLM typically involves two separate graphs: prefill and decode. I’ve already successfully exported the decode graph:
import torch
from torch.export import Dim
from transformers import AutoModelForCausalLM
from transformers.cache_utils import StaticCacheConfig
from transformers.integrations.executorch import convert_and_export_with_cache

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    device_map="cpu",
)

# batch = 1; export fails if the batch dimension is also made dynamic
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = StaticCacheConfig(
    batch_size=1, max_cache_len=1024, device="cpu"
)

# only the sequence dimension is dynamic
seq = Dim("seq", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {1: seq},
    "cache_position": {0: seq},
}

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids=torch.randint(0, 32000, (1, 4)),  # shape [1, 4]
    example_cache_position=torch.arange(4),
    dynamic_shapes=dynamic_shapes,
)
print("export success!", exported_program)
However, I’m having trouble exporting the prefill graph: the model internally uses vmap, which causes torch.export to fail. Or am I using it the wrong way? What would be the recommended way to export the prefill graph in this case? Here is my current attempt:
import torch
from torch.export import export, Dim
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    trust_remote_code=True,
).eval()
model.generation_config.use_cache = True

class PrefillWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor):
        # input_ids: [1, T]
        B, T = input_ids.shape
        position_ids = torch.arange(T, dtype=torch.long).unsqueeze(0)
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=None,
            past_key_values=None,
            use_cache=True,
        )
        return outputs.logits

wrapper = PrefillWrapper(model)
example_input = torch.tensor([[1, 2, 3]], dtype=torch.long)  # shape [1, T]
dynamic_shapes = {
    "input_ids": {1: Dim("seq", min=1, max=2048)},
}

ep = export(wrapper, (example_input,), dynamic_shapes=dynamic_shapes)
with open("prefill_dynamic_seq.ir", "w") as f:
    f.write(str(ep.module()))
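One workaround I have been considering is non-strict export. As I understand it, strict=False traces with the plain Python interpreter instead of TorchDynamo, which can be more tolerant of functorch transforms such as vmap, though I'm not sure whether this is the recommended fix. A sketch using the same wrapper and dynamic shapes as above:

# Identical to the export above except for strict=False, which selects
# non-strict tracing; whether this gets past the vmap failure is an
# untested assumption on my part.
ep_nonstrict = export(
    wrapper,
    (example_input,),
    dynamic_shapes=dynamic_shapes,
    strict=False,
)
with open("prefill_dynamic_seq_nonstrict.ir", "w") as f:
    f.write(str(ep_nonstrict.module()))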