CUDA OOM when exporting a large model to ONNX

I get a CUDA out-of-memory (OOM) error when exporting a large model to ONNX. I wonder how Optimum handles this issue.

Here are my settings:

  1. The code works for a smaller model with fewer parameters, so the error is due to the model size.
  2. I cannot export the model on the CPU because it is in fp16.
  3. The model alone takes 20 GB of CUDA memory, and the GPU has 80 GB in total. (Exporting a small model seems to consume about 10x its memory: 2 GB -> 20 GB. If that ratio holds, the large model would need roughly 200 GB.)
  4. Running multiple forward passes before the export causes no problems.
  5. A greedy search is implemented in the graph to generate 32 tokens, so many intermediate past_key_values are cached.

Very odd. It seems that the time and memory consumed to export a jit.ScriptModule are proportional to the number of loop iterations.

If this is true, it seems impossible to export a model with a decoding method into the ONNX computation graph. For example, compare these two modules, which differ only in loop count:

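import torch.nn as nn

class Model2(nn.Module):
    def forward(self, x):
        # When traced, this loop is unrolled into 2 separate multiplications
        for _ in range(2):
            x *= x
        return x

class Model32(nn.Module):
    def forward(self, x):
        # When traced, this loop is unrolled into 32 separate multiplications
        for _ in range(32):
            x *= x
        return x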
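A quick way to check the loop-size hypothesis: `torch.jit.trace` (which `torch.onnx.export` relies on for a plain `nn.Module`) records every op that actually runs, so a Python loop is unrolled into one node per iteration, whereas `torch.jit.script` keeps a single `prim::Loop` node. The sketch below counts graph nodes for both. `LoopModel`, `count_nodes`, and `results` are hypothetical names for illustration, and it assumes only that PyTorch is installed.

```python
import torch
import torch.nn as nn


class LoopModel(nn.Module):
    """Hypothetical toy module: squares its input n_iters times."""

    def __init__(self, n_iters: int):
        super().__init__()
        self.n_iters = n_iters

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for _ in range(self.n_iters):
            x = x * x  # out-of-place, so the example input is not mutated
        return x


def count_nodes(graph, kind: str) -> int:
    """Count top-level graph nodes of the given kind (loop bodies are not visited)."""
    return sum(1 for node in graph.nodes() if node.kind() == kind)


example = torch.ones(4)  # ones stay ones when squared, so no overflow
results = {}

for n in (2, 32):
    traced = torch.jit.trace(LoopModel(n), example)  # unrolls the Python loop
    scripted = torch.jit.script(LoopModel(n))        # keeps the loop as prim::Loop
    results[n] = (
        count_nodes(traced.graph, "aten::mul"),
        count_nodes(scripted.graph, "prim::Loop"),
    )
    print(f"n_iters={n}: traced aten::mul nodes={results[n][0]}, "
          f"scripted prim::Loop nodes={results[n][1]}")
```

The traced graph grows linearly with the iteration count while the scripted one does not, which is consistent with export cost scaling with the number of generated tokens. Whether a real decoder with past_key_values can be scripted and exported this way is not something this sketch checks.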

Thanks @in-certo. Could it be linked to this issue? `torch.jit.trace` memory usage increase although forward is constant, and gets much slower than forward with model depth increase · Issue #93943 · pytorch/pytorch · GitHub

I have also seen memory usage increase with the number of loop iterations when using torch.jit.trace with Stable Diffusion.

Yes, exactly.

Here is another related issue: ONNX model file exported from Transformer decoder is too large · Issue #4319 · onnx/onnx