I get an OOM error when exporting a large model to ONNX. I wonder how Optimum handles this issue.
Here is my setup and what I have observed so far:
The code works for a smaller model with fewer parameters, so the error is due to the model size.
I am not able to export the model on the CPU because it uses fp16 weights.
The model itself takes 20 GB of CUDA memory, and the total GPU capacity is 80 GB. (For a small model, export seems to consume roughly 10x the model's memory: 2 GB -> 20 GB.)
Running multiple forward passes before export does not cause any trouble.
A greedy search is implemented inside the graph to generate 32 tokens, so a lot of intermediate past_key_values are cached (a simplified sketch is below).
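To make the last point concrete, here is a simplified sketch (not my exact code) of the kind of in-graph greedy loop I mean. The decoder interface is hypothetical (a module returning logits and past_key_values), but it shows why 32 steps' worth of cached key/value tensors end up inside the exported graph.

import torch
import torch.nn as nn

class GreedyDecodeWrapper(nn.Module):
    """Runs a fixed number of greedy decoding steps inside forward(),
    so the loop and every step's past_key_values become part of the export."""

    def __init__(self, decoder: nn.Module, num_new_tokens: int = 32):
        super().__init__()
        self.decoder = decoder          # hypothetical: returns (logits, past_key_values)
        self.num_new_tokens = num_new_tokens

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # first step: run the full prompt to build the KV cache
        logits, past_key_values = self.decoder(input_ids, None)
        tokens = input_ids
        for _ in range(self.num_new_tokens):
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)
            # each step feeds only the new token plus the growing cache;
            # every step's intermediate tensors are captured during export
            logits, past_key_values = self.decoder(next_token, past_key_values)
        return tokens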
Very odd: the time and memory consumed to export a jit.ScriptModule seem to be proportional to the loop size.
If that is true, it seems impossible to export a model with a decoding loop into the ONNX computation graph.
import torch
import torch.nn as nn

class Model2(nn.Module):
    def forward(self, x):
        # loop body repeated 2 times
        for i in range(2):
            x *= x
        return x

class Model32(nn.Module):
    def forward(self, x):
        # identical body, but the loop runs 32 times
        for i in range(32):
            x *= x
        return x
I tried to convert a PyTorch model to ONNX but encountered an OOM error, even though running inference directly succeeds. After wrapping the export in `with torch.inference_mode():` before calling `torch.onnx.export`, I was able to export the model to ONNX without OOM.
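For anyone hitting the same issue, here is a minimal sketch of that workaround. The model class and dummy input are placeholders for whatever you are exporting; the idea is simply that disabling autograd bookkeeping during export lowers peak memory (torch.no_grad() behaves similarly, and the exact savings can depend on the PyTorch version).

import torch

model = MyLargeModel().half().cuda().eval()                        # placeholder model
dummy_input = torch.randint(0, 32000, (1, 128), device="cuda")     # placeholder input

# Disable gradient tracking while tracing/exporting to reduce peak memory.
with torch.inference_mode():
    torch.onnx.export(
        model,
        (dummy_input,),
        "model.onnx",
        input_names=["input_ids"],
        output_names=["output"],
        opset_version=17,
    )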