I got an OOM error when exporting a large model to ONNX. I wonder how Optimum handles this issue.
Here are my settings:
The code works for a smaller model with fewer parameters, so the error is due to the model size.
I am not able to export the model on the CPU because it is in fp16.
The model alone takes 20 GB of CUDA memory, and the total GPU capacity is 80 GB. (Exporting even a small model seems to consume about 10x its memory: 2 GB -> 20 GB.)
Running multiple forward passes before the export doesn't cause any trouble.
A greedy search is implemented inside the graph to generate 32 tokens, so a lot of intermediate past_key_values are cached.
Very odd. It seems that the time and memory consumed to export a jit.ScriptModule are proportional to the loop size.
If this is true, it seems impossible to export a model with a decoding loop built into the ONNX computation graph.
import torch.nn as nn

class Model2(nn.Module):
    def forward(self, x):
        # Loop with 2 iterations.
        for i in range(2):
            x *= x
        return x

class Model32(nn.Module):
    def forward(self, x):
        # Same computation, but with 32 iterations.
        for i in range(32):
            x *= x
        return x
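To make the scaling visible, here is a minimal sketch (not from the original post) that exports both toy modules with the default tracing-based `torch.onnx.export` and counts the nodes in the resulting graphs; the helper name is just illustrative. Because tracing unrolls the Python loop, the node count grows in proportion to the number of iterations:

import io

import onnx
import torch


def count_onnx_nodes(module, example_input):
    # Export to an in-memory buffer, then count the nodes in the exported graph.
    buf = io.BytesIO()
    torch.onnx.export(module, (example_input,), buf, opset_version=14)
    return len(onnx.load_from_string(buf.getvalue()).graph.node)


print(count_onnx_nodes(Model2(), torch.randn(1, 4)))   # roughly 2 Mul nodes, one per iteration
print(count_onnx_nodes(Model32(), torch.randn(1, 4)))  # roughly 32 Mul nodes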
I tried to convert a PyTorch model to ONNX but encountered an OOM error, even though running inference directly works fine. After wrapping `torch.onnx.export` in `with torch.inference_mode():`, I was able to export the model to ONNX without OOM.
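For reference, the pattern is simply wrapping the export call; here is a minimal, self-contained sketch where a tiny `nn.Linear` and a random input stand in for the real model:

import torch
import torch.nn as nn

# Sketch only: the tiny Linear and random input are placeholders for the real
# model and example input. inference_mode() disables autograd tracking during
# the trace, which avoids keeping gradient-related state alive.
model = nn.Linear(16, 16).eval()
dummy_input = torch.randn(1, 16)

with torch.inference_mode():
    torch.onnx.export(model, (dummy_input,), "model.onnx", opset_version=14)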
I tried it, but I still get the OOM error:

import os
import traceback
from pathlib import Path

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from optimum.exporters.onnx import onnx_export_from_model

# `pytorch_model_path`, `onnx_output_path`, `CustomGemma3NMultimodalOnnxConfig`,
# and `get_submodels` are assumed to be defined elsewhere in the script.

print("\nStep 2: Preparing and running the ONNX export...")
try:
    # The `zsh: killed` error is a classic Out-Of-Memory (OOM) error from the OS.
    # `torch.inference_mode()` helps by disabling all gradient calculations
    # during loading and exporting.
    with torch.inference_mode():
        # --- Pre-load check and fix for empty index files ---
        index_path = Path(pytorch_model_path) / "model.safetensors.index.json"
        if index_path.exists() and index_path.stat().st_size == 0:
            print(f"⚠️ Found an empty index file at: {index_path}")
            print("   This can cause loading errors. Removing it to proceed.")
            os.remove(index_path)
            print("   ✅ Empty index file removed.")

        # --- MPS DEBUGGING ---
        # Forcing CPU to bypass any MPS-specific bugs.
        device = "cpu"
        print(f"Using device: {device}")

        # Load the model and config ONCE to have better control over memory.
        print("Loading model and config from disk...")
        main_config = AutoConfig.from_pretrained(pytorch_model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            pytorch_model_path,
            config=main_config,
            trust_remote_code=True,
        ).to(device)
        print("Model loaded.")

        custom_onnx_configs = {
            "decoder_model": CustomGemma3NMultimodalOnnxConfig(config=main_config, task="text-generation", use_past=False),
            "decoder_with_past_model": CustomGemma3NMultimodalOnnxConfig(config=main_config, task="text-generation", use_past=True),
        }

        # Use the more direct `onnx_export_from_model`, which takes a pre-loaded model object.
        onnx_export_from_model(
            model=model,
            output=Path(onnx_output_path),
            task="text-generation-with-past",
            custom_onnx_configs=custom_onnx_configs,
            fn_get_submodels=get_submodels,
            opset=14,
            do_validation=False,
            device=device,
        )

    print("\n✅ ONNX conversion process completed successfully!")
    print(f"   The exported model is saved in: {Path(onnx_output_path).resolve()}")
except Exception:
    print("\n❌ An error occurred during the ONNX conversion process.")
    print("--- FULL TRACEBACK ---")
    traceback.print_exc()
    print("--- END OF TRACEBACK ---")
print(f"\nâ An error occurred during the ONNX conversion process.â); print(ââ FULL TRACEBACK ââ); traceback.print_exc(); print(ââ END OF TRACEBACK â")`